
Merge pull request #114 from Undertone0809/feat-v1.10.0
perf: optimize chat client
Undertone0809 authored Nov 5, 2023
2 parents afd3a0d + e93fd3a commit 3e5cffa
Showing 7 changed files with 75 additions and 46 deletions.
17 changes: 10 additions & 7 deletions Makefile
@@ -1,5 +1,14 @@
SHELL := /usr/bin/env bash
OS := $(shell python -c "import sys; print(sys.platform)")

ifeq ($(OS),win32)
PYTHONPATH := $(shell python -c "import os; print(os.getcwd())")
TEST_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
else
PYTHONPATH := `pwd`
TEST_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
endif

#* Poetry
.PHONY: poetry-download
poetry-download:
@@ -24,13 +33,7 @@ polish-codestyle:
.PHONY: formatting
formatting: polish-codestyle

ifeq ($(OS),win32)
PYTHONPATH := $(shell python -c "import os; print(os.getcwd())")
TEST_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
else
PYTHONPATH := `pwd`
TEST_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
endif


#* Linting
.PHONY: test
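The relocated ifeq block keys off the OS variable that the Makefile computes by shelling out to Python at the top of the file. For reference, a minimal Python sketch of that probe (values per CPython's documented behavior: "win32" on Windows, "linux" on Linux, "darwin" on macOS):

    import sys

    # Mirrors the Makefile's probe: python -c "import sys; print(sys.platform)"
    # CPython reports "win32" on both 32- and 64-bit Windows.
    if sys.platform == "win32":
        # Windows branch: set PYTHONPATH=... && poetry run pytest ...
        print("windows-style TEST_COMMAND")
    else:
        # POSIX branch: PYTHONPATH=`pwd` poetry run pytest ...
        print("posix-style TEST_COMMAND")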
4 changes: 2 additions & 2 deletions docs/images/coverage.svg
57 changes: 46 additions & 11 deletions promptulate/client/chat.py
@@ -24,10 +24,10 @@
import click
import questionary

from promptulate import Conversation
from promptulate.agents import ToolAgent
from promptulate.config import Config
from promptulate.llms import BaseLLM, ChatOpenAI, ErnieBot
from promptulate.schema import LLMType
from promptulate.schema import LLMType, MessageSet, SystemMessage
from promptulate.tools import (
ArxivQueryTool,
Calculator,
@@ -37,8 +37,10 @@
SleepTool,
)
from promptulate.tools.shell import ShellTool
from promptulate.utils import print_text, set_proxy_mode
from promptulate.utils.color_print import print_text
from promptulate.utils.proxy import set_proxy_mode

CFG = Config()
MODEL_MAPPING = {"OpenAI": ChatOpenAI, "ErnieBot": ErnieBot}
TOOL_MAPPING = {
"Calculator": Calculator,
@@ -51,28 +53,56 @@
}


def check_key(model_type: str):
model_key_mapping = {
"OpenAI": CFG.get_openai_api_key,
"ErnieBot": CFG.get_ernie_api_key,
}
model_key_mapping[model_type]()


def get_user_input() -> Optional[str]:
marker = (
"# You input are here (please delete this line)\n"
"You should save it and close the notebook after writing the prompt. (please delete this line)\n"
"Reply 'exit' to exit the chat.\n"
)
message = click.edit(marker)
if message is not None:
return message
return None

if message == "exit":
exit()

return message


def get_user_openai_api_key():
import os

api_key = questionary.password("Please enter your OpenAI API Key: ").ask()
os.environ["OPENAI_API_KEY"] = api_key


def simple_chat(llm: BaseLLM):
conversation = Conversation(llm=llm)
messages = MessageSet(
messages=[
SystemMessage(content="You are a helpful assistant."),
]
)

while True:
print_text("[User] ", "blue")
prompt = get_user_input()

if not prompt:
raise ValueError("Your prompt is empty, please input something.")

print_text(prompt, "blue")
ret = conversation.run(prompt)
print_text(f"[output] {ret}", "green")
messages.add_user_message(prompt)

answer = llm.predict(messages)
messages.add_message(answer)

print_text(f"[output] {answer.content}", "green")


def web_chat(llm: BaseLLM):
@@ -110,6 +140,7 @@ def agent_chat(agent: ToolAgent):


def chat():
# get parameters
parser = argparse.ArgumentParser(
description="Welcome to Promptulate Chat - The best chat terminal ever!"
)
@@ -131,14 +162,18 @@ def chat():

terminal_mode = questionary.select(
"Choose a chat terminal:",
choices=["Simple Chat", "Agent Chat", "Web Agent Chat"],
choices=["Simple Chat", "Agent Chat", "Web Agent Chat", "exit"],
).ask()

if terminal_mode == "exit":
exit(0)

model = questionary.select(
"Choose a llm model:",
choices=list(MODEL_MAPPING.keys()),
).ask()
# todo: check whether the llm key exists

check_key(model)

llm = MODEL_MAPPING[model](temperature=0.2)
if terminal_mode == "Simple Chat":
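Taken together, the client now drops the Conversation framework and drives the llm directly through a MessageSet. A minimal sketch of one turn of the reworked simple_chat flow, using only calls that appear in this diff (the prompt text is a placeholder):

    from promptulate.llms import ChatOpenAI
    from promptulate.schema import MessageSet, SystemMessage

    llm = ChatOpenAI(temperature=0.2)
    messages = MessageSet(
        messages=[SystemMessage(content="You are a helpful assistant.")]
    )

    # One chat turn: record the user prompt, ask the llm, then keep its
    # reply in the set so the next turn carries the full history.
    messages.add_user_message("Hello, what can you do?")
    answer = llm.predict(messages)
    messages.add_message(answer)
    print(answer.content)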
2 changes: 0 additions & 2 deletions promptulate/frameworks/conversation/conversation.py
@@ -23,7 +23,6 @@
from pydantic import Field, validator

from promptulate import utils
from promptulate.config import Config
from promptulate.frameworks.schema import BasePromptFramework
from promptulate.llms import ChatOpenAI
from promptulate.llms.base import BaseLLM
@@ -44,7 +43,6 @@
)
from promptulate.tips import EmptyMessageSetError

CFG = Config()
logger = utils.get_logger()


10 changes: 6 additions & 4 deletions promptulate/llms/base.py
@@ -23,7 +23,7 @@
from pydantic import BaseModel

from promptulate.hook import Hook, HookTable
from promptulate.schema import BaseMessage, LLMType, MessageSet
from promptulate.schema import AssistantMessage, LLMType, MessageSet


class BaseLLM(BaseModel, ABC):
@@ -41,15 +41,17 @@ def __init__(self, *args, **kwargs):
Hook.mount_instance_hook(hook, self)
Hook.call_hook(HookTable.ON_LLM_CREATE, self, **kwargs)

def predict(self, prompts: MessageSet, *args, **kwargs) -> BaseMessage:
def predict(self, prompts: MessageSet, *args, **kwargs) -> AssistantMessage:
"""llm generate prompt"""
Hook.call_hook(HookTable.ON_LLM_START, self, prompts, *args, **kwargs)
result: BaseMessage = self._predict(prompts, *args, **kwargs)
result: AssistantMessage = self._predict(prompts, *args, **kwargs)
Hook.call_hook(HookTable.ON_LLM_RESULT, self, result=result.content)
return result

@abstractmethod
def _predict(self, prompts: MessageSet, *args, **kwargs) -> Optional[BaseMessage]:
def _predict(
self, prompts: MessageSet, *args, **kwargs
) -> Optional[AssistantMessage]:
"""Run the llm, implemented through subclass."""

@abstractmethod
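The narrowed signature means every concrete llm's _predict now promises an AssistantMessage rather than a generic BaseMessage. A toy subclass sketching that contract; the AssistantMessage(content=...) constructor is inferred from the SystemMessage usage elsewhere in this commit, and any other abstract members of BaseLLM are stubbed:

    from typing import Optional

    from promptulate.llms.base import BaseLLM
    from promptulate.schema import AssistantMessage, MessageSet


    class EchoLLM(BaseLLM):
        """Hypothetical llm that parrots the last message; illustration only."""

        def _predict(
            self, prompts: MessageSet, *args, **kwargs
        ) -> Optional[AssistantMessage]:
            # New contract: return an AssistantMessage, which predict()
            # forwards after firing the ON_LLM_RESULT hook.
            return AssistantMessage(content=prompts.messages[-1].content)

        def __call__(self, *args, **kwargs):
            # Stub for the other abstract member hinted at by the trailing
            # @abstractmethod in this hunk.
            raise NotImplementedError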
2 changes: 1 addition & 1 deletion promptulate/llms/erniebot/erniebot.py
@@ -56,7 +56,7 @@ def __init__(self, *args, **kwargs):

def _predict(
self, prompts: MessageSet, stop: Optional[List[str]] = None, *args, **kwargs
) -> BaseMessage:
) -> AssistantMessage:
"""llm generate prompt"""
headers = {"Content-Type": "application/json"}
if self.model == "ernie-bot-turbo":
29 changes: 10 additions & 19 deletions promptulate/tools/iot_swith_mqtt/tools.py
@@ -3,7 +3,6 @@

from promptulate.llms import BaseLLM, ChatOpenAI
from promptulate.tools import Tool
from promptulate.tools.iot_swith_mqtt.api_wrapper import IotSwitchAPIWrapper
from promptulate.tools.iot_swith_mqtt.prompt import prompt_template
from promptulate.utils import StringTemplate, get_logger

@@ -24,25 +23,18 @@ class IotSwitchTool(Tool):
rule_table: List[Dict]

def __init__(
self,
client,
llm: BaseLLM = None,
rule_table: List[Dict] = None,
api_wrapper: IotSwitchAPIWrapper = IotSwitchAPIWrapper(),
**kwargs,
self, client, llm: BaseLLM = None, rule_table: List[Dict] = None, **kwargs
):
"""
Args:
llm: BaseLLM
client: mqtt.Client
rule_table: List[Dict]
api_wrapper: IotSwitchAPIWrapper
llm(BaseLLM): the llm used to match a rule; defaults to ChatOpenAI.
client(paho.mqtt.client.Client): a paho mqtt client that is already connected.
rule_table(List[Dict]): rules, each providing the MQTT "topic" and the command ("ask") to publish.
"""
self.api_wrapper = api_wrapper
self.client = client
self.llm: BaseLLM = llm or ChatOpenAI(
temperature=0.1, enable_default_system_prompt=False
)
self.client = client
self.rule_table = rule_table

super().__init__(**kwargs)
@@ -56,6 +48,7 @@ def _run(self, question: str, *args, **kwargs) -> str:
"This is needed in order to for IotSwitchTool. "
"Please install it with `pip install paho-mqtt`."
)

if len(self.rule_table) == 0:
raise ValueError("rule_table is empty")
else:
@@ -73,9 +66,7 @@ def _process_llm_result(self, llm_output: str) -> str:
if len(answer) == 0:
return "failure information :" + llm_output
else:
self.api_wrapper.run(
self.client,
self.rule_table[int(answer[0]) - 1]["topic"],
self.rule_table[int(answer[0]) - 1]["ask"],
)
return "ok"
topic = self.rule_table[int(answer[0]) - 1]["topic"]
command = self.rule_table[int(answer[0]) - 1]["ask"]
self.client.publish(topic, command)
return "success"
