From 32d345841c2e8a07cafbbb9d19f78bef596a6efd Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Wed, 5 Jun 2024 18:55:06 +0200 Subject: [PATCH] support for multiple function calls (#276) --- .../langchain_google_vertexai/chat_models.py | 61 ++++++++++--- .../tests/unit_tests/test_chat_models.py | 85 +++++++++++++++++++ 2 files changed, 136 insertions(+), 10 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index c6f6713e..9d7a0974 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -234,8 +234,12 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: system_parts: List[Part] | None = None system_instruction = None + # the last AI Message before a sequence of tool calls + prev_ai_message: Optional[AIMessage] = None + for i, message in enumerate(history): if isinstance(message, SystemMessage): + prev_ai_message = None if i != 0: raise ValueError("SystemMessage should be the first in the history.") if system_instruction is not None: @@ -254,6 +258,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: continue system_instruction = Content(role="user", parts=_convert_to_parts(message)) elif isinstance(message, HumanMessage): + prev_ai_message = None role = "user" parts = _convert_to_parts(message) if system_parts is not None: @@ -265,6 +270,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: system_parts = None vertex_messages.append(Content(role=role, parts=parts)) elif isinstance(message, AIMessage): + prev_ai_message = message role = "model" parts = [] @@ -282,6 +288,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: vertex_messages.append(Content(role=role, parts=parts)) elif isinstance(message, FunctionMessage): + prev_ai_message = None role = "function" part = Part( @@ -308,13 +315,17 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: # message.name can be null for ToolMessage name = message.name if name is None: - prev_message = history[i - 1] if i > 0 else None - if isinstance(prev_message, AIMessage): + if prev_ai_message: tool_call_id = message.tool_call_id tool_call: ToolCall | None = next( - (t for t in prev_message.tool_calls if t["id"] == tool_call_id), + ( + t + for t in prev_ai_message.tool_calls + if t["id"] == tool_call_id + ), None, ) + if tool_call is None: raise ValueError( ( @@ -323,12 +334,43 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: ) ) name = tool_call["name"] + + def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]: + if isinstance(raw_content, dict): + return raw_content + if isinstance(raw_content, str): + try: + content = json.loads(raw_content) + # json.loads("2") returns 2 since it's a valid json + if isinstance(content, dict): + return content + except json.JSONDecodeError: + pass + return {"content": raw_content} + + if isinstance(message.content, list): + parsed_content = [_parse_content(c) for c in message.content] + if len(parsed_content) > 1: + merged_content: Dict[Any, Any] = {} + for content_piece in parsed_content: + for key, value in content_piece.items(): + if key not in merged_content: + merged_content[key] = [] + merged_content[key].append(value) + logger.warning( + "Expected content to be a str, got a list with > 1 element." + "Merging values together" + ) + content = {k: "".join(v) for k, v in merged_content.items()} + else: + content = parsed_content[0] + else: + content = _parse_content(message.content) + part = Part( function_response=FunctionResponse( name=name, - response={ - "content": message.content, - }, + response=content, ) ) @@ -340,10 +382,9 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: # replacing last message vertex_messages[-1] = Content(role=role, parts=parts) continue - - parts = [part] - - vertex_messages.append(Content(role=role, parts=parts)) + else: + parts = [part] + vertex_messages.append(Content(role=role, parts=parts)) else: raise ValueError( f"Unexpected message with type {type(message)} at the position {i}." diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index 0038b2de..8fde4053 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -898,3 +898,88 @@ def test_safety_settings_gemini_init() -> None: ) safety_settings = model._safety_settings_gemini(None) assert safety_settings == expected_safety_setting + + +def test_multiple_fc() -> None: + prompt = ( + "I'm trying to decide whether to go to London or Zurich this weekend. How " + "hot are those cities? How about Singapore? Or maybe Tokyo. I want to go " + "somewhere not that cold but not too hot either. Suggest me." + ) + raw_history = [ + HumanMessage(content=prompt), + AIMessage( + content="", + tool_calls=[ + {"name": "get_weather", "args": {"location": "Munich"}, "id": "1"}, + {"name": "get_weather", "args": {"location": "London"}, "id": "2"}, + {"name": "get_weather", "args": {"location": "Berlin"}, "id": "3"}, + ], + ), + ToolMessage( + name="get_weather", + tool_call_id="1", + content='{"condition": "sunny", "temp_c": -23.9}', + ), + ToolMessage( + name="get_weather", + tool_call_id="2", + content='{"condition": "sunny", "temp_c": -30.0}', + ), + ToolMessage( + name="get_weather", + tool_call_id="3", + content='{"condition": "rainy", "temp_c": 25.2}', + ), + ] + _, history = _parse_chat_history_gemini(raw_history) + expected = [ + Content( + parts=[Part(text=prompt)], + role="user", + ), + Content( + parts=[ + Part( + function_call=FunctionCall( + name="get_weather", args={"location": "Munich"} + ) + ), + Part( + function_call=FunctionCall( + name="get_weather", args={"location": "London"} + ) + ), + Part( + function_call=FunctionCall( + name="get_weather", args={"location": "Berlin"} + ) + ), + ], + role="model", + ), + Content( + parts=[ + Part( + function_response=FunctionResponse( + name="get_weather", + response={"condition": "sunny", "temp_c": -23.9}, + ) + ), + Part( + function_response=FunctionResponse( + name="get_weather", + response={"condition": "sunny", "temp_c": -30.0}, + ) + ), + Part( + function_response=FunctionResponse( + name="get_weather", + response={"condition": "rainy", "temp_c": 25.2}, + ) + ), + ], + role="function", + ), + ] + assert history == expected