From 55ab35560c1212cb637716c4097d7e18f4af9cff Mon Sep 17 00:00:00 2001 From: Shane Huntley Date: Mon, 30 Sep 2024 08:27:43 +1000 Subject: [PATCH] feat: Add provide experimental support for Google Gemini #67 --- pyproject.toml | 1 + src/exchange/providers/__init__.py | 1 + src/exchange/providers/google.py | 154 +++++++++++++++++++++++++++++ tests/providers/test_google.py | 147 +++++++++++++++++++++++++++ tests/test_integration.py | 1 + 5 files changed, 304 insertions(+) create mode 100644 src/exchange/providers/google.py create mode 100644 tests/providers/test_google.py diff --git a/pyproject.toml b/pyproject.toml index eacca69..83a9e3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ databricks = "exchange.providers.databricks:DatabricksProvider" anthropic = "exchange.providers.anthropic:AnthropicProvider" bedrock = "exchange.providers.bedrock:BedrockProvider" ollama = "exchange.providers.ollama:OllamaProvider" +google = "exchange.providers.google:GoogleProvider" [project.entry-points."exchange.moderator"] passive = "exchange.moderators.passive:PassiveModerator" diff --git a/src/exchange/providers/__init__.py b/src/exchange/providers/__init__.py index 177ea63..ac7ed07 100644 --- a/src/exchange/providers/__init__.py +++ b/src/exchange/providers/__init__.py @@ -7,6 +7,7 @@ from exchange.providers.openai import OpenAiProvider # noqa from exchange.providers.ollama import OllamaProvider # noqa from exchange.providers.azure import AzureProvider # noqa +from exchange.providers.google import GoogleProvider # noqa from exchange.utils import load_plugins diff --git a/src/exchange/providers/google.py b/src/exchange/providers/google.py new file mode 100644 index 0000000..426aa79 --- /dev/null +++ b/src/exchange/providers/google.py @@ -0,0 +1,154 @@ +import os +from typing import Any, Dict, List, Tuple, Type + +import httpx + +from exchange import Message, Tool +from exchange.content import Text, ToolResult, ToolUse +from exchange.providers.base import 
Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status +from exchange.providers.utils import raise_for_status + +GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class GoogleProvider(Provider): + def __init__(self, client: httpx.Client) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": + url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) + try: + key = os.environ["GOOGLE_API_KEY"] + except KeyError: + raise RuntimeError( + "Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key" + ) + + client = httpx.Client( + base_url=url, + headers={ + "Content-Type": "application/json", + }, + params={"key": key}, + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: Dict) -> Usage: # noqa: ANN401 + usage = data.get("usageMetadata") + input_tokens = usage.get("promptTokenCount") + output_tokens = usage.get("candidatesTokenCount") + total_tokens = usage.get("totalTokenCount") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + @staticmethod + def google_response_to_message(response: Dict) -> Message: + candidates = response.get("candidates", []) + if candidates: + # Only use first candidate for now + candidate = candidates[0] + content_parts = candidate.get("content", {}).get("parts", []) + content = [] + for part in content_parts: + if "text" in part: + content.append(Text(text=part["text"])) + elif "functionCall" in part: + content.append( + ToolUse( + id=part["functionCall"].get("name", ""), + 
name=part["functionCall"].get("name", ""), + parameters=part["functionCall"].get("args", {}), + ) + ) + return Message(role="assistant", content=content) + + # If no valid candidates were found, return an empty message + return Message(role="assistant", content=[]) + + @staticmethod + def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]: + if not tools: + return {} + converted_tools = [] + for tool in tools: + converted_tool: Dict[str, Any] = { + "name": tool.name, + "description": tool.description or "", + } + if tool.parameters["properties"]: + converted_tool["parameters"] = tool.parameters + converted_tools.append(converted_tool) + return {"functionDeclarations": converted_tools} + + @staticmethod + def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]: + messages_spec = [] + for message in messages: + role = "user" if message.role == "user" else "model" + converted = {"role": role, "parts": []} + for content in message.content: + if isinstance(content, Text): + converted["parts"].append({"text": content.text}) + elif isinstance(content, ToolUse): + converted["parts"].append({"functionCall": {"name": content.name, "args": content.parameters}}) + elif isinstance(content, ToolResult): + converted["parts"].append( + {"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}} + ) + messages_spec.append(converted) + + if not messages_spec: + messages_spec.append({"role": "user", "parts": [{"text": "Ignore"}]}) + + return messages_spec + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: List[Tool] = [], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + tools_set = set() + unique_tools = [] + for tool in tools: + if tool.name not in tools_set: + unique_tools.append(tool) + tools_set.add(tool.name) + + payload = dict( + system_instruction={"parts": [{"text": system}]}, + contents=self.messages_to_google_spec(messages), + 
tools=self.tools_to_google_spec(tuple(unique_tools)), + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + response = self._post(payload, model) + message = self.google_response_to_message(response) + usage = self.get_usage(response) + return message, usage + + @retry_procedure + def _post(self, payload: dict, model: str) -> httpx.Response: + response = self.client.post("models/" + model + ":generateContent", json=payload) + return raise_for_status(response).json() diff --git a/tests/providers/test_google.py b/tests/providers/test_google.py new file mode 100644 index 0000000..47ad46b --- /dev/null +++ b/tests/providers/test_google.py @@ -0,0 +1,147 @@ +import os +from unittest.mock import patch + +import httpx +import pytest +from exchange import Message, Text +from exchange.content import ToolResult, ToolUse +from exchange.providers.google import GoogleProvider +from exchange.tool import Tool + + +def example_fn(param: str) -> None: + """ + Testing function. + + Args: + param (str): Description of param1 + """ + pass + + +@pytest.fixture +@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) +def google_provider(): + return GoogleProvider.from_env() + + +def test_google_response_to_text_message() -> None: + response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]} + message = GoogleProvider.google_response_to_message(response) + assert message.content[0].text == "Hello from Gemini!" 
def test_google_response_to_tool_use_message() -> None:
    """A functionCall part becomes a ToolUse whose id equals the function name."""
    response = {
        "candidates": [
            {
                "content": {
                    "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}],
                    "role": "model",
                }
            }
        ]
    }

    message = GoogleProvider.google_response_to_message(response)
    assert message.content[0].name == "example_fn"
    assert message.content[0].parameters == {"param": "value"}


def test_tools_to_google_spec() -> None:
    """A Tool built from a docstring'd function converts to a functionDeclaration."""
    tools = (Tool.from_function(example_fn),)
    expected_spec = {
        "functionDeclarations": [
            {
                "name": "example_fn",
                "description": "Testing function.",
                "parameters": {
                    "type": "object",
                    "properties": {"param": {"type": "string", "description": "Description of param1"}},
                    "required": ["param"],
                },
            }
        ]
    }
    result = GoogleProvider.tools_to_google_spec(tools)
    assert result == expected_spec


def test_message_text_to_google_spec() -> None:
    """A plain user text message maps to a single user part."""
    messages = [Message.user("Hello, Gemini")]
    expected_spec = [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}]
    result = GoogleProvider.messages_to_google_spec(messages)
    assert result == expected_spec


def test_messages_to_google_spec() -> None:
    """Text / ToolUse / ToolResult each map to their Gemini part shapes."""
    messages = [
        Message(role="user", content=[Text(text="Hello, Gemini")]),
        Message(
            role="assistant",
            content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})],
        ),
        Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]),
    ]
    actual_spec = GoogleProvider.messages_to_google_spec(messages)
    # assistant maps to role "model"; tool results become functionResponse parts
    expected_spec = [
        {"role": "user", "parts": [{"text": "Hello, Gemini"}]},
        {"role": "model", "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}]},
        {"role": "user", "parts": [{"functionResponse": {"name": "1", "response": {"content": "Result"}}}]},
    ]

    assert actual_spec == expected_spec


# Decorators apply bottom-up: logging.error -> mock_error (first arg), then
# logging.warning, then httpx.Client.post -> mock_post.
@patch("httpx.Client.post")
@patch("logging.warning")
@patch("logging.error")
def test_google_completion(mock_error, mock_warning, 
mock_post, google_provider): + mock_response = { + "candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}], + "usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13}, + } + + # First attempts fail with status code 429, 2nd succeeds + def create_response(status_code, json_data=None): + response = httpx.Response(status_code) + response._content = httpx._content.json_dumps(json_data or {}).encode() + response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/") + return response + + mock_post.side_effect = [ + create_response(429), # 1st attempt + create_response(200, mock_response), # Final success + ] + + model = "gemini-1.5-flash" + system = "You are a helpful assistant." + messages = [Message.user("Hello, Gemini")] + + reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages) + + assert reply_message.content == [Text(text="Hello from Gemini!")] + assert reply_usage.total_tokens == 13 + assert mock_post.call_count == 2 + mock_post.assert_any_call( + "models/gemini-1.5-flash:generateContent", + json={ + "system_instruction": {"parts": [{"text": "You are a helpful assistant."}]}, + "contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}], + }, + ) + + +@pytest.mark.integration +def test_google_integration(): + provider = GoogleProvider.from_env() + model = "gemini-1.5-flash" # updated model to a known valid model + system = "You are a helpful assistant." 
+ messages = [Message.user("Hello, Gemini")] + + # Run the completion + reply = provider.complete(model=model, system=system, messages=messages) + + assert reply[0].content is not None + print("Completion content from Google:", reply[0].content) diff --git a/tests/test_integration.py b/tests/test_integration.py index 8554e8a..1eb1980 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -17,6 +17,7 @@ (get_provider("azure"), os.getenv("AZURE_MODEL", "gpt-4o-mini"), dict()), (get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()), (get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()), + (get_provider("google"), "gemini-1.5-flash", dict()), ]