Skip to content

Commit

Permalink
Feat: OpenAIrobot类增加微软azure支持
Browse files Browse the repository at this point in the history
  • Loading branch information
mawwalker committed Mar 9, 2024
1 parent e409a4f commit d5e8171
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
49 changes: 36 additions & 13 deletions robot/AI.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def chat(self, texts, parsed):
logger.critical("AnyQ robot failed to response for %r", msg, exc_info=True)
return "抱歉, AnyQ回答失败"


class OPENAIRobot(AbstractRobot):

SLUG = "openai"
Expand All @@ -235,6 +234,8 @@ def __init__(
self,
openai_api_key,
model,
provider,
api_version,
temperature,
max_tokens,
top_p,
Expand Down Expand Up @@ -267,6 +268,8 @@ def __init__(
logger.critical("OpenAI 初始化失败,请升级 Python 版本至 > 3.6")
self.model = model
self.prefix = prefix
self.provider = provider
self.api_version = api_version
self.temperature = temperature
self.max_tokens = max_tokens
self.top_p = top_p
Expand Down Expand Up @@ -295,12 +298,20 @@ def stream_chat(self, texts):

header = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.openai.api_key,
# "Authorization": "Bearer " + self.openai.api_key,
}
if self.provider == 'openai':
header['Authorization'] = "Bearer " + self.openai.api_key,
elif self.provider == 'azure':
header['api-key'] = self.openai.api_key
else:
raise ValueError("Please check your config file, OpenAiRobot's provider should be openai or azure.")

data = {"model": self.model, "messages": self.context, "stream": True}
logger.info(f"使用模型:{self.model},开始流式请求")
url = self.api_base + "/completions"
if self.provider == 'azure':
url = f"{self.api_base}/openai/deployments/{self.model}/chat/completions?api-version={self.api_version}"
# 请求接收流式数据
try:
response = requests.request(
Expand Down Expand Up @@ -368,17 +379,29 @@ def chat(self, texts, parsed):
try:
respond = ""
self.context.append({"role": "user", "content": msg})
response = self.openai.Completion.create(
model=self.model,
messages=self.context,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
stop=self.stop_ai,
api_base=self.api_base
)
if self.provider == "openai":
response = self.openai.Completion.create(
model=self.model,
messages=self.context,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
stop=self.stop_ai,
api_base=self.api_base
)
else:
from openai import AzureOpenAI
client = AzureOpenAI(
azure_endpoint = self.api_base,
api_key=self.openai_api_key,
api_version=self.api_version
)
response = client.chat.completions.create(
model=self.model,
messages=self.context
)
message = response.choices[0].message
respond = message.content
self.context.append(message)
Expand Down
2 changes: 2 additions & 0 deletions static/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ tuling:
# 注册一个账号,获得 openai_api_key 后填到下面的配置中即可
openai:
openai_api_key: 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxx'
provider: 'azure' # openai的接口填写openai, azure的填写azure
api_version: '2023-05-15' # 如果是openai的,留空就行,azure的需填写对应的api_version,参考官方文档
# 参数指定将生成文本的模型类型。目前支持 gpt-3.5-turbo 和 gpt-3.5-turbo-0301 两种选择
model: 'gpt-3.5-turbo'
# 在前面加的一段前缀
Expand Down

0 comments on commit d5e8171

Please sign in to comment.