-
Notifications
You must be signed in to change notification settings - Fork 379
/
Copy pathentry_point.py
46 lines (39 loc) · 2.17 KB
/
entry_point.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import argparse
from discord_app import discord_main
from app import gradio_main
SUPPORTED_APP_MODES = ("GRADIO", "DISCORD")
DEFAULT_APP_MODE = "GRADIO"

# Environment values (compared case-insensitively, after stripping whitespace)
# that mean "off". Any other non-blank value keeps meaning "on", as before.
_FALSY_ENV_VALUES = frozenset({"", "0", "false", "no", "off", "n", "f"})


def parse_env_flag(value):
    """Interpret the raw text of an environment variable as a boolean.

    Args:
        value: The string returned by ``os.getenv`` or ``None`` when the
            variable is unset.

    Returns:
        ``False`` when the variable is unset, blank, or one of the common
        "off" spellings (``0``, ``false``, ``no``, ``off``, ``n``, ``f``);
        ``True`` for every other value.
    """
    if value is None:
        return False
    # A plain ``bool(value)`` would treat "false" or "0" as True because any
    # non-empty string is truthy.
    return value.strip().lower() not in _FALSY_ENV_VALUES


def resolve_app_mode(value):
    """Return ``value`` if it is a supported app mode, else the default.

    Unset (``None``) and unrecognised values both fall back to ``"GRADIO"``.
    The comparison is case-sensitive.
    """
    return value if value in SUPPORTED_APP_MODES else DEFAULT_APP_MODE


def build_parser(app_mode, local_files_only=False, serper_api_key=None):
    """Build the command-line parser for the given application mode.

    Args:
        app_mode: ``"GRADIO"`` or ``"DISCORD"``; any other value yields a
            parser with no application-specific options.
        local_files_only: Default for ``--local-files-only`` in DISCORD mode.
        serper_api_key: Default for ``--serper-api-key`` in both modes.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    on_off = argparse.BooleanOptionalAction

    if app_mode == "GRADIO":
        parser.add_argument('--root-path', default="")
        # NOTE(review): unlike DISCORD mode, this default does not follow
        # LLMCHAT_LOCAL_FILES_ONLY. Kept as-is; confirm whether the
        # environment variable is meant to apply to Gradio as well.
        parser.add_argument('--local-files-only', default=False, action=on_off)
        parser.add_argument('--share', default=False, action=on_off)
        parser.add_argument('--debug', default=False, action=on_off)
        parser.add_argument('--serper-api-key', default=serper_api_key, type=str)
    elif app_mode == "DISCORD":
        parser.add_argument('--token', default=None, type=str)
        parser.add_argument('--model-name', default=None, type=str)
        parser.add_argument('--max-workers', default=1, type=int)
        parser.add_argument('--mode-cpu', default=False, action=on_off)
        parser.add_argument('--mode-mps', default=False, action=on_off)
        parser.add_argument('--mode-8bit', default=False, action=on_off)
        parser.add_argument('--mode-4bit', default=False, action=on_off)
        parser.add_argument('--mode-full-gpu', default=True, action=on_off)
        parser.add_argument('--local-files-only', default=local_files_only, action=on_off)
        parser.add_argument('--serper-api-key', default=serper_api_key, type=str)
        parser.add_argument('--tgi-server-addr', default=None, type=str)
        parser.add_argument('--tgi-server-port', default=None, type=str)

    return parser


def main():
    """Read the LLMCHAT_* environment settings and launch the selected app.

    ``LLMCHAT_APP_MODE`` selects between the Gradio and Discord front ends
    (Gradio when unset or unrecognised). ``LLMCHAT_LOCAL_FILES_ONLY`` and
    ``LLMCHAT_SERPER_API_KEY`` provide defaults for the matching
    command-line options.
    """
    app_mode = resolve_app_mode(os.getenv("LLMCHAT_APP_MODE"))
    local_files_only = parse_env_flag(os.getenv("LLMCHAT_LOCAL_FILES_ONLY"))
    serper_api_key = os.getenv("LLMCHAT_SERPER_API_KEY")

    args = build_parser(app_mode, local_files_only, serper_api_key).parse_args()

    if app_mode == "GRADIO":
        gradio_main(args)
    else:
        discord_main(args)


if __name__ == "__main__":
    main()