Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: AgentLegoToolkit #164

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
14 changes: 10 additions & 4 deletions examples/internlm2_agent_web_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.actions import PPT, ActionExecutor, ArxivSearch, BINGMap, GoogleScholar, IPythonInterpreter
# from lagent.actions.agentlego_wrapper import AgentLegoToolkit
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
Expand All @@ -23,6 +24,11 @@ def init_state(self):

action_list = [
ArxivSearch(),
PPT(),
BINGMap(key='Your api key' # noqa
),
GoogleScholar(api_key='Your api key' # noqa
)
]
st.session_state['plugin_map'] = {
action.name: action
Expand Down Expand Up @@ -104,7 +110,7 @@ def setup_sidebar(self):
actions=[IPythonInterpreter()])
else:
st.session_state['chatbot']._interpreter_executor = None
st.session_state['chatbot']._protocol._meta_template = meta_prompt
st.session_state['chatbot']._protocol.meta_prompt = meta_prompt
st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
st.session_state[
'chatbot']._protocol.interpreter_prompt = da_prompt
Expand Down Expand Up @@ -141,8 +147,8 @@ def initialize_chatbot(self, model, plugin_action):
plugin='<|plugin|>', interpreter='<|interpreter|>'),
belong='assistant',
end='<|action_end|>\n',
), ),
max_turn=7)
)),
max_turn=15)

def render_user(self, prompt: str):
with st.chat_message('user'):
Expand Down
48 changes: 48 additions & 0 deletions lagent/actions/agentlego_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Optional

# from agentlego.parsers import DefaultParser
from agentlego.tools.remote import RemoteTool

from lagent import BaseAction
from lagent.actions.parser import JsonParser


class AgentLegoToolkit(BaseAction):

def __init__(self,
name: str,
url: Optional[str] = None,
text: Optional[str] = None,
spec_dict: Optional[dict] = None,
parser=JsonParser,
enable: bool = True):

if url is not None:
spec = dict(url=url)
elif text is not None:
spec = dict(text=text)
else:
assert spec_dict is not None
spec = dict(spec_dict=spec_dict)
if url is not None and not url.endswith('.json'):
api_list = [RemoteTool.from_url(url).to_lagent()]
else:
api_list = [
api.to_lagent() for api in RemoteTool.from_openapi(**spec)
]
api_desc = []
for api in api_list:
api_desc.append(api.description)
if len(api_list) > 1:
tool_description = dict(name=name, api_list=api_desc)
for func in api_list:
setattr(self, func.name, func.run)
else:
tool_description = api_desc[0]
setattr(self, 'run', api_list[0].run)
super().__init__(
description=tool_description, parser=parser, enable=enable)

@property
def is_toolkit(self):
return 'api_list' in self.description
4 changes: 2 additions & 2 deletions lagent/actions/builtin_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def run(self, err_msg: Optional[str] = None) -> ActionReturn:
action_return = ActionReturn(
url=None,
args=dict(text=err_msg),
errmsg=err_msg or self._err_msg,
errmsg=str(err_msg) or self._err_msg,
type=self.name,
valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR)
Expand Down Expand Up @@ -76,7 +76,7 @@ def run(self, err_msg: Optional[str] = None) -> ActionReturn:
url=None,
args=dict(text=err_msg),
type=self.name,
errmsg=err_msg or self._err_msg,
errmsg=str(err_msg) or self._err_msg,
valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR)
return action_return
Expand Down
30 changes: 20 additions & 10 deletions lagent/agents/internlm2_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def format_plugin(message):
else:
new_message.append(
dict(role=message['role'], content=new_content))

return new_message

def format(self,
Expand All @@ -125,9 +124,10 @@ def format(self,
if self.meta_prompt:
formatted.append(dict(role='system', content=self.meta_prompt))
if interpreter_executor and self.interpreter_prompt:
interpreter_info = interpreter_executor.get_actions_info()[0]
interpreter_prompt = self.interpreter_prompt.format(
code_prompt=interpreter_info['description'])
# interpreter_info = interpreter_executor.get_actions_info()[0]
# interpreter_prompt = self.interpreter_prompt.format(
# code_prompt=interpreter_info['description'])
interpreter_prompt = self.interpreter_prompt
formatted.append(
dict(
role='system',
Expand Down Expand Up @@ -169,20 +169,30 @@ def parse(self, message, plugin_executor: ActionExecutor,
action = action.split(self.tool['end'].strip())[0]
return 'plugin', message, action
if self.tool['name_map']['interpreter'] in message:
message, code = message.split(
f"{self.tool['start_token']}"
f"{self.tool['name_map']['interpreter']}")
try:
message, code, *_ = message.split(
f"{self.tool['start_token']}"
f"{self.tool['name_map']['interpreter']}")
# message, code, *_ = message.split(f"{self.tool['start_token']}")
# _, code, *_ = code.split(f"{self.tool['name_map']['interpreter']}")
except ValueError:
message, code, *_ = message.split(
self.tool['name_map']['interpreter'])
tool_start_idx = message.rfind(self.tool['start_token'])
if tool_start_idx != -1:
message = message[:tool_start_idx]
message = message.strip()
code = code.split(self.tool['end'].strip())[0].strip()
return 'interpreter', message, dict(
name=interpreter_executor.action_names()[0],
parameters=dict(command=code))
name='IPythonInterpreter', parameters=dict(
command=code)) if interpreter_executor else None
return None, message.split(self.tool['start_token'])[0], None

def format_response(self, action_return, name) -> dict:
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.format_result()
else:
response = action_return.errmsg
response = str(action_return.errmsg)
content = self.execute['begin'] + response + self.execute['end']
if self.execute.get('fallback_role'):
return dict(
Expand Down
87 changes: 66 additions & 21 deletions lagent/llms/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
Expand All @@ -10,6 +11,8 @@

from .base_api import BaseAPIModel

warnings.simplefilter('default')

OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'


Expand Down Expand Up @@ -45,6 +48,7 @@ def __init__(self,
model_type: str = 'gpt-3.5-turbo',
query_per_second: int = 1,
retry: int = 2,
json_mode: bool = False,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = [
Expand All @@ -53,13 +57,19 @@ def __init__(self,
dict(role='assistant', api_role='assistant')
],
openai_api_base: str = OPENAI_API_BASE,
proxies: Optional[Dict] = None,
**gen_params):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
DeprecationWarning)
gen_params.pop('top_k')
super().__init__(
model_type=model_type,
meta_template=meta_template,
query_per_second=query_per_second,
retry=retry,
**gen_params)
self.gen_params.pop('top_k')
self.logger = getLogger(__name__)

if isinstance(key, str):
Expand All @@ -79,16 +89,8 @@ def __init__(self,
self.org_ctr = 0
self.url = openai_api_base
self.model_type = model_type

# max num token for gpt-3.5-turbo is 4097
context_window = 4096
if '32k' in self.model_type:
context_window = 32768
elif '16k' in self.model_type:
context_window = 16384
elif 'gpt-4' in self.model_type:
context_window = 8192
self.context_window = context_window
self.proxies = proxies
self.json_mode = json_mode

def chat(
self,
Expand Down Expand Up @@ -118,6 +120,27 @@ def chat(
ret = [task.result() for task in tasks]
return ret[0] if isinstance(inputs[0], dict) else ret

def stream_chat(
self,
inputs: List[dict],
**gen_params,
) -> str:
"""Generate responses given the contexts.

Args:
inputs (List[dict]): a list of messages
gen_params: additional generation configuration

Returns:
str: generated string
"""
assert isinstance(inputs, list)
if 'max_tokens' in gen_params:
raise NotImplementedError('unsupported parameter: max_tokens')
gen_params = {**self.gen_params, **gen_params}
gen_params['stream'] = True
yield from self._chat(inputs, **gen_params)

def _chat(self, messages: List[dict], **gen_params) -> str:
"""Generate completion from a list of templates.

Expand All @@ -132,9 +155,7 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
gen_params = gen_params.copy()

# Hold out 100 tokens due to potential errors in tiktoken calculation
max_tokens = min(
gen_params.pop('max_new_tokens'),
self.context_window - len(self.tokenize(str(input))) - 100)
max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
if max_tokens <= 0:
return ''

Expand Down Expand Up @@ -170,27 +191,49 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
header['OpenAI-Organization'] = self.orgs[self.org_ctr]

try:
gen_params_new = gen_params.copy()
data = dict(
model=self.model_type,
messages=messages,
max_tokens=max_tokens,
n=1,
stop=gen_params.pop('stop_words'),
frequency_penalty=gen_params.pop('repetition_penalty'),
**gen_params,
stop=gen_params_new.pop('stop_words'),
frequency_penalty=gen_params_new.pop('repetition_penalty'),
**gen_params_new,
)
if self.json_mode:
data['response_format'] = {'type': 'json_object'}
raw_response = requests.post(
self.url, headers=header, data=json.dumps(data))
self.url,
headers=header,
data=json.dumps(data),
proxies=self.proxies)
if 'stream' not in data or not data['stream']:
response = raw_response.json()
return response['choices'][0]['message']['content'].strip()
else:
resp = ''
for chunk in raw_response.iter_lines(
chunk_size=8192, decode_unicode=False,
delimiter=b'\n'):
if chunk:
decoded = chunk.decode('utf-8')
if decoded == 'data: [DONE]':
return
if decoded[:6] == 'data: ':
decoded = decoded[6:]
response = json.loads(decoded)
choice = response['choices'][0]
if choice['finish_reason'] == 'stop':
return
resp += choice['delta']['content'].strip()
yield resp
except requests.ConnectionError:
print('Got connection error, retrying...')
continue
try:
response = raw_response.json()
except requests.JSONDecodeError:
print('JsonDecode error, got', str(raw_response.content))
continue
try:
return response['choices'][0]['message']['content'].strip()
except KeyError:
if 'error' in response:
if response['error']['code'] == 'rate_limit_exceeded':
Expand All @@ -203,6 +246,8 @@ def _chat(self, messages: List[dict], **gen_params) -> str:

print('Find error message in response: ',
str(response['error']))
except Exception as error:
print(str(error))
max_num_retries += 1

raise RuntimeError('Calling OpenAI failed after retrying for '
Expand Down
7 changes: 6 additions & 1 deletion lagent/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import asdict, dataclass, field
from enum import IntEnum
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union


def enum_dict_factory(inputs):
Expand Down Expand Up @@ -77,12 +77,17 @@ class AgentStatusCode(IntEnum):
CODING = 6 # start python
CODE_END = 7 # end python
CODE_RETURN = 8 # python return
ANSWER_ING = 9 # final answer is in streaming


@dataclass
class AgentReturn:
type: str = ''
content: str = ''
state: Union[AgentStatusCode, int] = AgentStatusCode.END
actions: List[ActionReturn] = field(default_factory=list)
response: str = ''
inner_steps: List = field(default_factory=list)
nodes: Dict = None
adjacency_list: Dict = None
errmsg: Optional[str] = None
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
agentlego
google-search-results
lmdeploy>=0.2.3
pillow
Expand Down