
Commit

Update webui
黄宇扬 committed Jul 9, 2024
1 parent f663549 commit c494755
Showing 5 changed files with 162 additions and 6 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -56,6 +56,10 @@ python3 -m ftllm.chat -t 16 -p ~/Qwen2-7B-Instruct/ --dtype int8
# Requires dependencies: pip install -r requirements-server.txt
# This starts a server named 'qwen' on port 8080
python3 -m ftllm.server -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080 --model_name qwen

# webui
# Requires dependencies: pip install streamlit-chat
python3 -m ftllm.webui -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080
```

All of the demos above accept the --help flag to show detailed parameter descriptions.
4 changes: 4 additions & 0 deletions README_EN.md
@@ -54,6 +54,10 @@ python3 -m ftllm.chat -t 16 -p ~/Qwen2-7B-Instruct/ --dtype int8
# Requires dependencies: pip install -r requirements-server.txt
# Opens a server named 'qwen' on port 8080
python3 -m ftllm.server -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080 --model_name qwen

# webui
# Requires dependencies: pip install streamlit-chat
python3 -m ftllm.webui -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080
```

Detailed parameters can be viewed using the --help argument for all demos.
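If the server exposes an OpenAI-compatible API (an assumption here, not something this diff confirms), the endpoint started above could be exercised with the official openai client; a minimal sketch, with the endpoint path and response shape assumed:

```python
# Hedged sketch: assumes ftllm.server serves an OpenAI-compatible
# /v1/chat/completions endpoint on the port chosen above.
# pip install openai
from openai import OpenAI

client = OpenAI(base_url = "http://localhost:8080/v1", api_key = "none")  # key unused locally
resp = client.chat.completions.create(
    model = "qwen",  # matches --model_name qwen from the server command above
    messages = [{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)
```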
69 changes: 69 additions & 0 deletions tools/fastllm_pytools/web_demo.py
@@ -0,0 +1,69 @@

from ftllm import llm
import sys
import os
import argparse
from util import make_normal_parser
from util import make_normal_llm_model

def parse_args():
    parser = make_normal_parser("fastllm webui")
    parser.add_argument("--port", type = int, default = 8080, help = "API server port")
    parser.add_argument("--title", type = str, default = "fastllm webui", help = "Page title")
    return parser.parse_args()

args = parse_args()

import streamlit as st
from streamlit_chat import message
st.set_page_config(
    page_title = args.title,
    page_icon = ":robot:"
)

@st.cache_resource
def get_model():
    args = parse_args()
    model = make_normal_llm_model(args)
    return model

if "messages" not in st.session_state:
    st.session_state.messages = []

max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step = 0.01)
top_k = st.sidebar.slider("top_k", 1, 100, 1, step = 1)
temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step = 0.01)
repeat_penalty = st.sidebar.slider("repeat_penalty", 1.0, 10.0, 1.0, step = 0.05)

buttonClean = st.sidebar.button("Clear chat history", key="clean")
if buttonClean:
    st.session_state.messages = []
    st.rerun()

for i, (prompt, response) in enumerate(st.session_state.messages):
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        st.markdown(response)

if prompt := st.chat_input("Start chatting"):
    model = get_model()
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        for chunk in model.stream_response(prompt,
                                           st.session_state.messages,
                                           max_length = max_new_tokens,
                                           top_k = top_k,
                                           top_p = top_p,
                                           temperature = temperature,
                                           repeat_penalty = repeat_penalty,
                                           one_by_one = True):
            full_response += chunk
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
    st.session_state.messages.append((prompt, full_response))
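For comparison, the same helpers imported at the top of this script, together with the stream_response signature shown above, can drive a plain console loop outside Streamlit; a minimal sketch, assuming util.py is importable from the working directory as it is for this script, with sampling arguments and error handling omitted:

```python
# Minimal console sketch reusing the helpers imported by web_demo.py.
# Assumes util.make_normal_parser / util.make_normal_llm_model are importable.
from util import make_normal_parser, make_normal_llm_model

def main():
    args = make_normal_parser("fastllm console demo").parse_args()
    model = make_normal_llm_model(args)
    history = []  # list of (prompt, response) tuples, same shape as st.session_state.messages
    while True:
        prompt = input("user: ")
        response = ""
        for chunk in model.stream_response(prompt, history, one_by_one = True):
            response += chunk
            print(chunk, end = "", flush = True)
        print()
        history.append((prompt, response))

if __name__ == "__main__":
    main()
```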
20 changes: 20 additions & 0 deletions tools/fastllm_pytools/webui.py
@@ -0,0 +1,20 @@
try:
    import streamlit as st
except ImportError:
    print("Please install streamlit-chat. (pip install streamlit-chat)")
    exit(0)

import os
import sys

if __name__ == "__main__":
    current_path = os.path.dirname(os.path.abspath(__file__))
    web_demo_path = os.path.join(current_path, 'web_demo.py')
    port = ""
    for i in range(len(sys.argv)):
        if sys.argv[i] == "--port":
            port = "--server.port " + sys.argv[i + 1]
        if sys.argv[i] == "--help" or sys.argv[i] == "-h":
            os.system("python3 " + web_demo_path + " --help")
            exit(0)
    os.system("streamlit run " + port + " " + web_demo_path + ' -- ' + ' '.join(sys.argv[1:]))
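The launcher above forwards arguments by concatenating a shell string for os.system; as a point of comparison, here is a hedged sketch of the same forwarding done with subprocess so that arguments containing spaces survive quoting (the --port and --help handling mirrors the script above):

```python
import os
import subprocess
import sys

if __name__ == "__main__":
    current_path = os.path.dirname(os.path.abspath(__file__))
    web_demo_path = os.path.join(current_path, "web_demo.py")
    if "--help" in sys.argv or "-h" in sys.argv:
        subprocess.run([sys.executable, web_demo_path, "--help"])
        raise SystemExit(0)
    cmd = ["streamlit", "run"]
    if "--port" in sys.argv:
        idx = sys.argv.index("--port")
        if idx + 1 < len(sys.argv):
            cmd += ["--server.port", sys.argv[idx + 1]]
    # Everything after "--" is passed through to web_demo.py's own argparse.
    cmd += [web_demo_path, "--"] + sys.argv[1:]
    subprocess.run(cmd, check = True)
```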
71 changes: 65 additions & 6 deletions tools/scripts/web_demo.py
@@ -1,21 +1,73 @@
-import streamlit as st
-from streamlit_chat import message

from ftllm import llm
import sys
import os
import argparse

def make_normal_parser(des: str) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description = des)
    parser.add_argument('-p', '--path', type = str, required = True, default = '', help = 'Model path: a fastllm model file or a HF model directory')
    parser.add_argument('-t', '--threads', type = int, default = 4, help = 'Number of threads')
    parser.add_argument('-l', '--low', action = 'store_true', help = 'Use low-memory mode')
    parser.add_argument('--dtype', type = str, default = "float16", help = 'Weight dtype (effective when loading a HF model)')
    parser.add_argument('--atype', type = str, default = "float32", help = 'Inference dtype: float32 or float16')
    parser.add_argument('--cuda_embedding', action = 'store_true', help = 'Run the embedding layer on CUDA')
    parser.add_argument('--device', type = str, help = 'Device(s) to use')
    return parser

def parse_args():
    parser = make_normal_parser("fastllm webui")
    parser.add_argument("--port", type = int, default = 8080, help = "API server port")
    parser.add_argument("--title", type = str, default = "fastllm webui", help = "Page title")
    return parser.parse_args()

def make_normal_llm_model(args):
    if (args.device and args.device != ""):
        try:
            import ast
            device_map = ast.literal_eval(args.device)
            if (isinstance(device_map, list) or isinstance(device_map, dict)):
                llm.set_device_map(device_map)
            else:
                llm.set_device_map(args.device)
        except:
            llm.set_device_map(args.device)
    llm.set_cpu_threads(args.threads)
    llm.set_cpu_low_mem(args.low)
    if (args.cuda_embedding):
        llm.set_cuda_embedding(True)
    model = llm.model(args.path, dtype = args.dtype, tokenizer_type = "auto")
    model.set_atype(args.atype)
    return model

args = parse_args()
import streamlit as st
from streamlit_chat import message
st.set_page_config(
page_title="fastllm web demo",
page_icon=":robot:"
page_title = args.title,
page_icon = ":robot:"
)

@st.cache_resource
def get_model():
-    model = llm.model(sys.argv[1])
    args = parse_args()
    model = make_normal_llm_model(args)
    return model

if "messages" not in st.session_state:
    st.session_state.messages = []

max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step = 0.01)
top_k = st.sidebar.slider("top_k", 1, 100, 1, step = 1)
temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step = 0.01)
repeat_penalty = st.sidebar.slider("repeat_penalty", 1.0, 10.0, 1.0, step = 0.05)

buttonClean = st.sidebar.button("Clear chat history", key="clean")
if buttonClean:
    st.session_state.messages = []
    st.rerun()

for i, (prompt, response) in enumerate(st.session_state.messages):
    with st.chat_message("user"):
        st.markdown(prompt)
@@ -30,7 +82,14 @@ def get_model():
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
-        for chunk in model.stream_response(prompt, st.session_state.messages, one_by_one = True):
        for chunk in model.stream_response(prompt,
                                           st.session_state.messages,
                                           max_length = max_new_tokens,
                                           top_k = top_k,
                                           top_p = top_p,
                                           temperature = temperature,
                                           repeat_penalty = repeat_penalty,
                                           one_by_one = True):
            full_response += chunk
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
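make_normal_llm_model above parses --device with ast.literal_eval and falls back to the raw string when the value is not a Python literal. A small illustration of that parsing logic; the device names and map values below are placeholders, not values confirmed by fastllm:

```python
import ast

# Illustrative inputs only; accepted device names and map semantics depend on fastllm.
for raw in ["cuda", '["cuda:0", "cuda:1"]', '{"cuda:0": 1, "cpu": 1}']:
    try:
        parsed = ast.literal_eval(raw)
        if isinstance(parsed, (list, dict)):
            print("device map:", parsed)    # would be passed to llm.set_device_map(parsed)
        else:
            print("device string:", raw)    # non-container literals fall back to the string form
    except (ValueError, SyntaxError):
        print("device string:", raw)        # plain names like "cuda" are not valid literals
```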
