support for multiple function calls (#276)
lkuligin authored Jun 5, 2024
1 parent 1cba20a commit 32d3458
Showing 2 changed files with 136 additions and 10 deletions.
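Before this change, _parse_chat_history_gemini resolved a nameless ToolMessage by checking whether history[i - 1] was an AIMessage, so when one AIMessage issued several tool calls, every tool response after the first failed the lookup. The commit instead tracks the last AIMessage seen (prev_ai_message, reset on any other message type) and matches each ToolMessage.tool_call_id against its tool_calls. A minimal standalone sketch of the new lookup, illustrative only and not the library code:

from typing import Optional

from langchain_core.messages import AIMessage, ToolMessage

history = [
    AIMessage(
        content="",
        tool_calls=[
            {"name": "get_weather", "args": {"location": "London"}, "id": "1"},
            {"name": "get_weather", "args": {"location": "Zurich"}, "id": "2"},
        ],
    ),
    ToolMessage(name="get_weather", tool_call_id="1", content='{"temp_c": 8.0}'),
    ToolMessage(name="get_weather", tool_call_id="2", content='{"temp_c": 5.5}'),
]

prev_ai_message: Optional[AIMessage] = None
for i, message in enumerate(history):
    if isinstance(message, AIMessage):
        prev_ai_message = message
    elif isinstance(message, ToolMessage):
        # Old approach: check isinstance(history[i - 1], AIMessage), which is
        # True only for the first ToolMessage; the second follows another
        # ToolMessage, so the name lookup failed.
        # New approach: match against the tracked AIMessage instead.
        assert prev_ai_message is not None
        tool_call = next(
            t for t in prev_ai_message.tool_calls if t["id"] == message.tool_call_id
        )
        print(tool_call["name"], tool_call["args"])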
61 changes: 51 additions & 10 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -234,8 +234,12 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
     system_parts: List[Part] | None = None
     system_instruction = None
 
+    # the last AI Message before a sequence of tool calls
+    prev_ai_message: Optional[AIMessage] = None
+
     for i, message in enumerate(history):
         if isinstance(message, SystemMessage):
+            prev_ai_message = None
             if i != 0:
                 raise ValueError("SystemMessage should be the first in the history.")
             if system_instruction is not None:
@@ -254,6 +258,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                 continue
             system_instruction = Content(role="user", parts=_convert_to_parts(message))
         elif isinstance(message, HumanMessage):
+            prev_ai_message = None
             role = "user"
             parts = _convert_to_parts(message)
             if system_parts is not None:
@@ -265,6 +270,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                 system_parts = None
             vertex_messages.append(Content(role=role, parts=parts))
         elif isinstance(message, AIMessage):
+            prev_ai_message = message
             role = "model"
 
             parts = []
@@ -282,6 +288,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
 
             vertex_messages.append(Content(role=role, parts=parts))
         elif isinstance(message, FunctionMessage):
+            prev_ai_message = None
             role = "function"
 
             part = Part(
@@ -308,13 +315,17 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
             # message.name can be null for ToolMessage
             name = message.name
             if name is None:
-                prev_message = history[i - 1] if i > 0 else None
-                if isinstance(prev_message, AIMessage):
+                if prev_ai_message:
                     tool_call_id = message.tool_call_id
                     tool_call: ToolCall | None = next(
-                        (t for t in prev_message.tool_calls if t["id"] == tool_call_id),
+                        (
+                            t
+                            for t in prev_ai_message.tool_calls
+                            if t["id"] == tool_call_id
+                        ),
                         None,
                     )
+
                     if tool_call is None:
                         raise ValueError(
                             (
@@ -323,12 +334,43 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                             )
                         )
                     name = tool_call["name"]
+
+            def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]:
+                if isinstance(raw_content, dict):
+                    return raw_content
+                if isinstance(raw_content, str):
+                    try:
+                        content = json.loads(raw_content)
+                        # json.loads("2") returns 2 since it's a valid json
+                        if isinstance(content, dict):
+                            return content
+                    except json.JSONDecodeError:
+                        pass
+                return {"content": raw_content}
+
+            if isinstance(message.content, list):
+                parsed_content = [_parse_content(c) for c in message.content]
+                if len(parsed_content) > 1:
+                    merged_content: Dict[Any, Any] = {}
+                    for content_piece in parsed_content:
+                        for key, value in content_piece.items():
+                            if key not in merged_content:
+                                merged_content[key] = []
+                            merged_content[key].append(value)
+                    logger.warning(
+                        "Expected content to be a str, got a list with > 1 element."
+                        "Merging values together"
+                    )
+                    content = {k: "".join(v) for k, v in merged_content.items()}
+                else:
+                    content = parsed_content[0]
+            else:
+                content = _parse_content(message.content)
+
             part = Part(
                 function_response=FunctionResponse(
                     name=name,
-                    response={
-                        "content": message.content,
-                    },
+                    response=content,
                 )
             )
 
@@ -340,10 +382,9 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                 # replacing last message
                 vertex_messages[-1] = Content(role=role, parts=parts)
                 continue
-
-            parts = [part]
-
-            vertex_messages.append(Content(role=role, parts=parts))
+            else:
+                parts = [part]
+                vertex_messages.append(Content(role=role, parts=parts))
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
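A side note on the new _parse_content helper above: it normalizes a tool response into the dict payload that FunctionResponse expects. A small self-contained sketch of that behavior; the helper is copied from the hunk, the sample values are made up:

from __future__ import annotations

import json
from typing import Any, Dict

def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]:
    # Copied from the hunk above for illustration.
    if isinstance(raw_content, dict):
        return raw_content
    if isinstance(raw_content, str):
        try:
            content = json.loads(raw_content)
            # json.loads("2") returns 2, which is valid JSON but not a dict
            if isinstance(content, dict):
                return content
        except json.JSONDecodeError:
            pass
    return {"content": raw_content}

assert _parse_content({"temp_c": 25.2}) == {"temp_c": 25.2}
assert _parse_content('{"temp_c": 25.2}') == {"temp_c": 25.2}
assert _parse_content("2") == {"content": "2"}          # valid JSON, not a dict
assert _parse_content("sunny") == {"content": "sunny"}  # not JSON at all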
85 changes: 85 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -898,3 +898,88 @@ def test_safety_settings_gemini_init() -> None:
     )
     safety_settings = model._safety_settings_gemini(None)
     assert safety_settings == expected_safety_setting
+
+
+def test_multiple_fc() -> None:
+    prompt = (
+        "I'm trying to decide whether to go to London or Zurich this weekend. How "
+        "hot are those cities? How about Singapore? Or maybe Tokyo. I want to go "
+        "somewhere not that cold but not too hot either. Suggest me."
+    )
+    raw_history = [
+        HumanMessage(content=prompt),
+        AIMessage(
+            content="",
+            tool_calls=[
+                {"name": "get_weather", "args": {"location": "Munich"}, "id": "1"},
+                {"name": "get_weather", "args": {"location": "London"}, "id": "2"},
+                {"name": "get_weather", "args": {"location": "Berlin"}, "id": "3"},
+            ],
+        ),
+        ToolMessage(
+            name="get_weather",
+            tool_call_id="1",
+            content='{"condition": "sunny", "temp_c": -23.9}',
+        ),
+        ToolMessage(
+            name="get_weather",
+            tool_call_id="2",
+            content='{"condition": "sunny", "temp_c": -30.0}',
+        ),
+        ToolMessage(
+            name="get_weather",
+            tool_call_id="3",
+            content='{"condition": "rainy", "temp_c": 25.2}',
+        ),
+    ]
+    _, history = _parse_chat_history_gemini(raw_history)
+    expected = [
+        Content(
+            parts=[Part(text=prompt)],
+            role="user",
+        ),
+        Content(
+            parts=[
+                Part(
+                    function_call=FunctionCall(
+                        name="get_weather", args={"location": "Munich"}
+                    )
+                ),
+                Part(
+                    function_call=FunctionCall(
+                        name="get_weather", args={"location": "London"}
+                    )
+                ),
+                Part(
+                    function_call=FunctionCall(
+                        name="get_weather", args={"location": "Berlin"}
+                    )
+                ),
+            ],
+            role="model",
+        ),
+        Content(
+            parts=[
+                Part(
+                    function_response=FunctionResponse(
+                        name="get_weather",
+                        response={"condition": "sunny", "temp_c": -23.9},
+                    )
+                ),
+                Part(
+                    function_response=FunctionResponse(
+                        name="get_weather",
+                        response={"condition": "sunny", "temp_c": -30.0},
+                    )
+                ),
+                Part(
+                    function_response=FunctionResponse(
+                        name="get_weather",
+                        response={"condition": "rainy", "temp_c": 25.2},
+                    )
+                ),
+            ],
+            role="function",
+        ),
+    ]
+    assert history == expected
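One detail the expected history highlights: all three ToolMessages end up in a single function-role Content, because the prev_content_is_function branch in chat_models.py appends each new function part onto a trailing function message instead of emitting a new one. As a quick sanity check, appending these asserts to the test above would also pass:

    assert [c.role for c in history] == ["user", "model", "function"]
    assert [len(c.parts) for c in history] == [1, 3, 3]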
