Merge branch 'master' of https://github.com/ztxz16/fastllm

黄宇扬 committed May 9, 2024
2 parents 6081a40 + dedaee4 commit f25be77
Showing 6 changed files with 600 additions and 0 deletions.
56 changes: 56 additions & 0 deletions example/openai_server/README.md
@@ -0,0 +1,56 @@
# OpenAI Compatible API Server
## Introduction
This is an HTTP server compatible with the OpenAI API. It supports loading fastllm `.flm` models; HTTPS access is not supported yet.
> This implementation is based on the OpenAI-compatible API server in vllm v0.4.1 and simplifies it.

## Currently Supported Endpoints
* Official OpenAI API reference: https://platform.openai.com/docs/api-reference/

| Type | Endpoint | Method | Options explicitly not supported | Official docs and examples |
| :--- | :--------------------- | :------------------ | :--------------------------------------------------------- | :--------------------------------------------------------- |
| Chat | Create chat completion | v1/chat/completions | (n, presence_penalty, tools, functions, logprobs, seed, logit_bias) | https://platform.openai.com/docs/api-reference/chat/create |


> The Completions endpoint is not implemented because it has been marked as a legacy endpoint.

## Dependencies
These dependencies have been verified on Python 3.12.2.
1. Install the fastllm_pytools package first.
2. Install the remaining dependencies:
```bash
cd example/openai_server
pip install -r requirements.txt
```

## Usage and Examples
* Start the server:
```bash
cd example/openai_server
python openai_api_server.py --model_name "model_name" -p "path_to_your_flm_model"
# e.g.: python openai_api_server.py --model_name "chat-glm2-6b-int4" -p "./chatglm2-6b-int4.flm"
```
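* Optional server flags (as defined in `openai_api_server.py`): `-t/--threads` sets the number of threads, `-l/--low` enables low-memory mode, and `--host`/`--port` set the listen address. A sketch using the same example model, with an arbitrary thread count:
```bash
cd example/openai_server
python openai_api_server.py --model_name "chat-glm2-6b-int4" -p "./chatglm2-6b-int4.flm" -t 8 -l --host 0.0.0.0 --port 8080
```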

* Test with a client:
```bash
# Example request
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer something" \
  -d '{
        "model": "chat-glm2-6b-int4",
        "messages": [
          {
            "role": "system",
            "content": "You are a helpful assistant."
          },
          {
            "role": "user",
            "content": "Hello!"
          }
        ]
      }'
# Example response
{"id":"fastllm-chat-glm2-6b-int4-e4fd6bea564548f6ae95f6327218616d","object":"chat.completion","created":1715150460,"model":"chat-glm2-6b-int4","choices":[{"index":0,"message":{"role":"assistant","content":" Hello! How can I assist you today?"},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"total_tokens":0,"completion_tokens":0}}
```
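* Streaming test: when `"stream": true` is set, the server responds with Server-Sent Events (`data: {...}` chunks ending with `data: [DONE]`), per the stream handling in `fastllm_completion.py`. A minimal sketch against the same example server:
```bash
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer something" \
  -d '{
        "model": "chat-glm2-6b-int4",
        "stream": true,
        "messages": [
          {"role": "user", "content": "Hello!"}
        ]
      }'
```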


269 changes: 269 additions & 0 deletions example/openai_server/fastllm_completion.py
@@ -0,0 +1,269 @@
import logging
import json
import time
import traceback
from fastapi import Request
from http import HTTPStatus
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
Optional, Tuple, TypedDict, Union, final)
import uuid
from openai.types.chat import (ChatCompletionContentPartParam,
ChatCompletionRole)

from protocal.openai_protocol import *
from fastllm_pytools import llm

class ConversationMessage:
def __init__(self, role:str, content:str):
self.role = role
self.content = content

def random_uuid() -> str:
return str(uuid.uuid4().hex)

class FastLLmCompletion:
def __init__(self,
model_name,
model_path,
low_mem_mode = False,
cpu_thds = 16):
self.model_name = model_name
self.model_path = model_path
self.low_mem_mode = low_mem_mode
self.cpu_thds = cpu_thds
self.init_fast_llm_model()

def init_fast_llm_model(self):
llm.set_cpu_threads(self.cpu_thds)
llm.set_cpu_low_mem(self.low_mem_mode)
self.model = llm.model(self.model_path)

def create_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message,
type=err_type,
code=status_code.value)

async def _check_model(self, request: ChatCompletionRequest):
if request.model != self.model_name:
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
else:
return None

def _parse_chat_message_content(
self,
role: ChatCompletionRole,
content: Optional[Union[str,
Iterable[ChatCompletionContentPartParam]]],
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
if content is None:
return [], []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)], []

        # Image input is not supported yet
raise NotImplementedError("Complex input not supported yet")

async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret

query:str = ""
history:List[Tuple[str, str]] = []
try:
conversation: List[ConversationMessage] = []
for m in request.messages:
messages, _ = self._parse_chat_message_content(
m["role"], m["content"])

conversation.extend(messages)

            # In the fastllm examples, history must be alternating question/answer pairs; system prompts are not supported yet
if len(conversation) == 0:
raise Exception("Empty msg")

for i in range(len(conversation)):
msg = conversation[i]
if msg.role == "system":
                    # fastllm does not support system prompts yet
continue
elif msg.role == "user":
if i + 1 < len(conversation):
next_msg = conversation[i + 1]
if next_msg.role == "assistant":
history.append((msg.content, next_msg.content))
else:
                            # Messages must follow the alternating user, assistant, user, assistant format
                            raise Exception("fastllm requires that prompts alternate between the user and assistant roles.")
elif msg.role == "assistant":
if i - 1 < 0:
raise Exception("fastllm Not Support assistant prompt in first message")
else:
pre_msg = conversation[i - 1]
if pre_msg.role != "user":
raise Exception("In FastLLM, The message role before the assistant msg must be user")
else:
raise NotImplementedError(f"prompt role {msg.role } not supported yet")

last_msg = conversation[-1]
if last_msg.role != "user":
raise Exception("last msg role must be user")
query = last_msg.content

except Exception as e:
logging.error("Error in applying chat template from request: %s", e)
traceback.print_exc()
return self.create_error_response(str(e))

request_id = f"fastllm-{self.model_name}-{random_uuid()}"

frequency_penalty = 1.0
if request.frequency_penalty and request.frequency_penalty != 0.0:
frequency_penalty = request.frequency_penalty

max_length = request.max_tokens if request.max_tokens else 8192
logging.info(request)
logging.info(f"fastllm input: {query}")
logging.info(f"fastllm history: {history}")
        # Results from stream_response do not include token usage statistics
result_generator = self.model.stream_response(query, history,
max_length = max_length, do_sample = True,
top_p = request.top_p, top_k = request.top_k, temperature = request.temperature,
repeat_penalty = frequency_penalty, one_by_one = True)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
except ValueError as e:
return self.create_error_response(str(e))

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator,
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.model_name
created_time = int(time.time())
result = ""
for res in result_generator:
result += res

choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=result),
logprobs=None,
finish_reason='stop',
)

        # TODO: fill in usage info, including the prompt token count and the number of generated tokens
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
usage=UsageInfo(),
)

return response


async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str) -> AsyncGenerator[str, None]:
model_name = self.model_name
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"

        # TODO: support the request.n and request.echo options
first_iteration = True
try:
if first_iteration:
                # 1. The role chunk
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
first_iteration = False

            # 2. The content chunks
for res in result_generator:
delta_text = res
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=delta_text),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"

            # 3. The end-of-stream marker
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
logprobs=None,
finish_reason='stop')
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True,
exclude_none=True)
yield f"data: {data}\n\n"
except ValueError as e:
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"

def create_streaming_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
json_str = json.dumps({
"error":
self.create_error_response(message=message,
err_type=err_type,
status_code=status_code).model_dump()
})
return json_str

70 changes: 70 additions & 0 deletions example/openai_server/openai_api_server.py
@@ -0,0 +1,70 @@
import argparse
import fastapi
import logging
import sys
import uvicorn
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

from protocal.openai_protocol import *
from fastllm_completion import FastLLmCompletion

def parse_args():
parser = argparse.ArgumentParser(description="OpenAI-compatible API server")
parser.add_argument("--model_name", type = str, help = "部署的模型名称, 调用api时会进行名称核验", required=True)
parser.add_argument("--host", type = str, default="0.0.0.0", help = "API Server host")
parser.add_argument("--port", type = int, default = 8080, help = "API Server Port")
parser.add_argument('-p', '--path', type = str, required = True, default = '', help = 'fastllm model path(end with .flm)')
parser.add_argument('-t', '--threads', type=int, default=4, help='fastllm model 使用的线程数量')
parser.add_argument('-l', '--low', action='store_true', help='fastllm 是否使用低内存模式')
return parser.parse_args()


app = fastapi.FastAPI()
# Allowed request origins; adjust this for production deployments
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

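# Set in __main__ below before the server starts handling requests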
fastllm_completion:FastLLmCompletion

@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await fastllm_completion.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())

def init_logging(log_level = logging.INFO, log_file:str = None):
logging_format = '%(asctime)s %(process)d %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
root = logging.getLogger()
root.setLevel(log_level)
if log_file is not None:
logging.basicConfig(level=log_level, filemode='a', filename=log_file, format=logging_format)
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(logging.Formatter(logging_format))
root.addHandler(stdout_handler)


if __name__ == "__main__":
init_logging()
args = parse_args()
logging.info(args)
fastllm_completion = FastLLmCompletion(model_name = args.model_name,
model_path = args.path,
low_mem_mode = args.low,
cpu_thds = args.threads)
uvicorn.run(app, host=args.host, port=args.port)
