From e93fd3a1d5f4db39889b847632326e3c38b195aa Mon Sep 17 00:00:00 2001
From: Zeeland <287017217@qq.com>
Date: Sun, 5 Nov 2023 22:19:52 +0800
Subject: [PATCH] perf: optimize chat client

---
 Makefile                                     | 17 +++---
 docs/images/coverage.svg                     |  4 +-
 promptulate/client/chat.py                   | 57 +++++++++++++++----
 .../frameworks/conversation/conversation.py  |  2 -
 promptulate/llms/base.py                     | 10 ++--
 promptulate/llms/erniebot/erniebot.py        |  2 +-
 promptulate/tools/iot_swith_mqtt/tools.py    | 29 ++++------
 7 files changed, 75 insertions(+), 46 deletions(-)

diff --git a/Makefile b/Makefile
index e5a4d15b..fdaa9853 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,14 @@
+SHELL := /usr/bin/env bash
 OS := $(shell python -c "import sys; print(sys.platform)")
 
+ifeq ($(OS),win32)
+    PYTHONPATH := $(shell python -c "import os; print(os.getcwd())")
+    TEST_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
+else
+    PYTHONPATH := `pwd`
+    TEST_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
+endif
+
 #* Poetry
 .PHONY: poetry-download
 poetry-download:
@@ -24,13 +33,7 @@ polish-codestyle:
 
 .PHONY: formatting
 formatting: polish-codestyle
 
-ifeq ($(OS),win32)
-    PYTHONPATH := $(shell python -c "import os; print(os.getcwd())")
-    TEST_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
-else
-    PYTHONPATH := `pwd`
-    TEST_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
-endif
+
 #* Linting
 .PHONY: test
diff --git a/docs/images/coverage.svg b/docs/images/coverage.svg
index 4f8c1853..9d027c7d 100644
--- a/docs/images/coverage.svg
+++ b/docs/images/coverage.svg
@@ -15,7 +15,7 @@
 coverage
 coverage
-50%
-50%
+49%
+49%
diff --git a/promptulate/client/chat.py b/promptulate/client/chat.py
index 6a594d63..82794cd8 100644
--- a/promptulate/client/chat.py
+++ b/promptulate/client/chat.py
@@ -24,10 +24,10 @@
 import click
 import questionary
 
-from promptulate import Conversation
 from promptulate.agents import ToolAgent
+from promptulate.config import Config
 from promptulate.llms import BaseLLM, ChatOpenAI, ErnieBot
-from promptulate.schema import LLMType
+from promptulate.schema import LLMType, MessageSet, SystemMessage
 from promptulate.tools import (
     ArxivQueryTool,
     Calculator,
@@ -37,8 +37,10 @@
     SleepTool,
 )
 from promptulate.tools.shell import ShellTool
-from promptulate.utils import print_text, set_proxy_mode
+from promptulate.utils.color_print import print_text
+from promptulate.utils.proxy import set_proxy_mode
 
+CFG = Config()
 MODEL_MAPPING = {"OpenAI": ChatOpenAI, "ErnieBot": ErnieBot}
 TOOL_MAPPING = {
     "Calculator": Calculator,
@@ -51,28 +53,56 @@
 }
 
 
+def check_key(model_type: str):
+    model_key_mapping = {
+        "OpenAI": CFG.get_openai_api_key,
+        "ErnieBot": CFG.get_ernie_api_key,
+    }
+    model_key_mapping[model_type]()
+
+
 def get_user_input() -> Optional[str]:
     marker = (
         "# Your input is here (please delete this line)\n"
         "You should save it and close the editor after writing the prompt. (please delete this line)\n"
+        "Reply 'exit' to exit the chat.\n"
     )
     message = click.edit(marker)
-    if message is not None:
-        return message
-    return None
+
+    if message == "exit":
+        exit()
+
+    return message
+
+
+def get_user_openai_api_key():
+    import os
+
+    api_key = questionary.password("Please enter your OpenAI API Key: ").ask()
+    os.environ["OPENAI_API_KEY"] = api_key
 
 
 def simple_chat(llm: BaseLLM):
-    conversation = Conversation(llm=llm)
+    messages = MessageSet(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+        ]
+    )
 
     while True:
         print_text("[User] ", "blue")
         prompt = get_user_input()
+
         if not prompt:
             raise ValueError("Your prompt is None, please input something.")
+
         print_text(prompt, "blue")
-        ret = conversation.run(prompt)
-        print_text(f"[output] {ret}", "green")
+        messages.add_user_message(prompt)
+
+        answer = llm.predict(messages)
+        messages.add_message(answer)
+
+        print_text(f"[output] {answer.content}", "green")
 
 
 def web_chat(llm: BaseLLM):
@@ -110,6 +140,7 @@ def agent_chat(agent: ToolAgent):
 
 
 def chat():
+    # get parameters
     parser = argparse.ArgumentParser(
         description="Welcome to Promptulate Chat - The best chat terminal ever!"
     )
@@ -131,14 +162,18 @@ def chat():
 
     terminal_mode = questionary.select(
         "Choose a chat terminal:",
-        choices=["Simple Chat", "Agent Chat", "Web Agent Chat"],
+        choices=["Simple Chat", "Agent Chat", "Web Agent Chat", "exit"],
     ).ask()
 
+    if terminal_mode == "exit":
+        exit(0)
+
     model = questionary.select(
         "Choose a llm model:",
         choices=list(MODEL_MAPPING.keys()),
     ).ask()
-    # todo check whether exist llm key
+
+    check_key(model)
     llm = MODEL_MAPPING[model](temperature=0.2)
 
     if terminal_mode == "Simple Chat":
diff --git a/promptulate/frameworks/conversation/conversation.py b/promptulate/frameworks/conversation/conversation.py
index eac5e774..cf07b3e6 100644
--- a/promptulate/frameworks/conversation/conversation.py
+++ b/promptulate/frameworks/conversation/conversation.py
@@ -23,7 +23,6 @@
 from pydantic import Field, validator
 
 from promptulate import utils
-from promptulate.config import Config
 from promptulate.frameworks.schema import BasePromptFramework
 from promptulate.llms import ChatOpenAI
 from promptulate.llms.base import BaseLLM
@@ -44,7 +43,6 @@
 )
 from promptulate.tips import EmptyMessageSetError
 
-CFG = Config()
 logger = utils.get_logger()
 
 
diff --git a/promptulate/llms/base.py b/promptulate/llms/base.py
index 85a971e8..a15c4a08 100644
--- a/promptulate/llms/base.py
+++ b/promptulate/llms/base.py
@@ -23,7 +23,7 @@
 from pydantic import BaseModel
 
 from promptulate.hook import Hook, HookTable
-from promptulate.schema import BaseMessage, LLMType, MessageSet
+from promptulate.schema import AssistantMessage, LLMType, MessageSet
 
 
 class BaseLLM(BaseModel, ABC):
@@ -41,15 +41,17 @@ def __init__(self, *args, **kwargs):
             Hook.mount_instance_hook(hook, self)
         Hook.call_hook(HookTable.ON_LLM_CREATE, self, **kwargs)
 
-    def predict(self, prompts: MessageSet, *args, **kwargs) -> BaseMessage:
+    def predict(self, prompts: MessageSet, *args, **kwargs) -> AssistantMessage:
         """llm generate prompt"""
         Hook.call_hook(HookTable.ON_LLM_START, self, prompts, *args, **kwargs)
-        result: BaseMessage = self._predict(prompts, *args, **kwargs)
+        result: AssistantMessage = self._predict(prompts, *args, **kwargs)
         Hook.call_hook(HookTable.ON_LLM_RESULT, self, result=result.content)
         return result
 
     @abstractmethod
-    def _predict(self, prompts: MessageSet, *args, **kwargs) -> Optional[BaseMessage]:
+    def _predict(
+        self, prompts: MessageSet, *args, **kwargs
+    ) -> Optional[AssistantMessage]:
         """Run the llm, implemented through subclass."""
 
     @abstractmethod
diff --git a/promptulate/llms/erniebot/erniebot.py b/promptulate/llms/erniebot/erniebot.py
index 58271286..0a85d868 100644
--- a/promptulate/llms/erniebot/erniebot.py
+++ b/promptulate/llms/erniebot/erniebot.py
@@ -56,7 +56,7 @@ def __init__(self, *args, **kwargs):
 
     def _predict(
         self, prompts: MessageSet, stop: Optional[List[str]] = None, *args, **kwargs
-    ) -> BaseMessage:
+    ) -> AssistantMessage:
         """llm generate prompt"""
         headers = {"Content-Type": "application/json"}
         if self.model == "ernie-bot-turbo":
diff --git a/promptulate/tools/iot_swith_mqtt/tools.py b/promptulate/tools/iot_swith_mqtt/tools.py
index baf0fd06..087b49eb 100644
--- a/promptulate/tools/iot_swith_mqtt/tools.py
+++ b/promptulate/tools/iot_swith_mqtt/tools.py
@@ -3,7 +3,6 @@
 
 from promptulate.llms import BaseLLM, ChatOpenAI
 from promptulate.tools import Tool
-from promptulate.tools.iot_swith_mqtt.api_wrapper import IotSwitchAPIWrapper
 from promptulate.tools.iot_swith_mqtt.prompt import prompt_template
 from promptulate.utils import StringTemplate, get_logger
 
@@ -24,25 +23,18 @@ class IotSwitchTool(Tool):
     rule_table: List[Dict]
 
     def __init__(
-        self,
-        client,
-        llm: BaseLLM = None,
-        rule_table: List[Dict] = None,
-        api_wrapper: IotSwitchAPIWrapper = IotSwitchAPIWrapper(),
-        **kwargs,
+        self, client, llm: BaseLLM = None, rule_table: List[Dict] = None, **kwargs
     ):
         """
         Args:
-            llm: BaseLLM
-            client: mqtt.Client
-            rule_table: List[Dict]
-            api_wrapper: IotSwitchAPIWrapper
+            llm(BaseLLM): optional LLM, defaults to ChatOpenAI.
+            client(paho.mqtt.client.Client): a connected paho MQTT client.
+            rule_table(List[Dict]): rules, each providing the MQTT topic and the command ("ask") to publish.
         """
-        self.api_wrapper = api_wrapper
+        self.client = client
         self.llm: BaseLLM = llm or ChatOpenAI(
             temperature=0.1, enable_default_system_prompt=False
         )
-        self.client = client
         self.rule_table = rule_table
 
         super().__init__(**kwargs)
@@ -56,6 +48,7 @@ def _run(self, question: str, *args, **kwargs) -> str:
                 "This is needed in order for IotSwitchTool to work. "
                 "Please install it with `pip install paho-mqtt`."
             )
+
         if len(self.rule_table) == 0:
             raise ValueError("rule_table is empty")
         else:
@@ -73,9 +66,7 @@ def _process_llm_result(self, llm_output: str) -> str:
         if len(answer) == 0:
             return "failure information: " + llm_output
         else:
-            self.api_wrapper.run(
-                self.client,
-                self.rule_table[int(answer[0]) - 1]["topic"],
-                self.rule_table[int(answer[0]) - 1]["ask"],
-            )
-            return "ok"
+            topic = self.rule_table[int(answer[0]) - 1]["topic"]
+            command = self.rule_table[int(answer[0]) - 1]["ask"]
+            self.client.publish(topic, command)
+            return "success"
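
After this patch, simple_chat keeps history in a MessageSet and calls llm.predict, which now
returns an AssistantMessage, instead of going through the removed Conversation framework. A
minimal sketch of the new flow, assuming a valid OPENAI_API_KEY is configured; the question
text is only illustrative:

    from promptulate.llms import ChatOpenAI
    from promptulate.schema import MessageSet, SystemMessage

    # seed the history with a system prompt, as simple_chat now does
    messages = MessageSet(
        messages=[SystemMessage(content="You are a helpful assistant.")]
    )
    messages.add_user_message("What does MQTT QoS 1 guarantee?")

    llm = ChatOpenAI(temperature=0.2)
    answer = llm.predict(messages)  # an AssistantMessage after this patch
    messages.add_message(answer)    # keep the reply so the next turn has context

    print(answer.content)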
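
IotSwitchTool likewise drops IotSwitchAPIWrapper and publishes straight through the injected
paho client. A sketch of the expected wiring under paho-mqtt 1.x, assuming an already-connected
client; the broker address, topic, command payload, and the descriptive "content" field of each
rule are hypothetical, since only the "topic" and "ask" keys appear in this diff:

    import paho.mqtt.client as mqtt

    from promptulate.tools.iot_swith_mqtt.tools import IotSwitchTool

    # the tool expects a client that is already connected to a broker
    client = mqtt.Client()
    client.connect("broker.example.com", 1883)  # hypothetical broker

    rule_table = [
        {
            "content": "living room light",  # hypothetical description for the LLM to match
            "topic": "/light/living_room",   # topic published when this rule is chosen
            "ask": "ON",                     # command payload sent to that topic
        },
    ]

    tool = IotSwitchTool(client=client, rule_table=rule_table)
    result = tool.run("Please turn on the living room light")
    print(result)  # "success", or "failure information: ..." when no rule matched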