Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add provide experimental support for Google Gemini #67 #68

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ databricks = "exchange.providers.databricks:DatabricksProvider"
anthropic = "exchange.providers.anthropic:AnthropicProvider"
bedrock = "exchange.providers.bedrock:BedrockProvider"
ollama = "exchange.providers.ollama:OllamaProvider"
google = "exchange.providers.google:GoogleProvider"

[project.entry-points."exchange.moderator"]
passive = "exchange.moderators.passive:PassiveModerator"
Expand Down
1 change: 1 addition & 0 deletions src/exchange/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from exchange.providers.openai import OpenAiProvider # noqa
from exchange.providers.ollama import OllamaProvider # noqa
from exchange.providers.azure import AzureProvider # noqa
from exchange.providers.google import GoogleProvider # noqa

from exchange.utils import load_plugins

Expand Down
154 changes: 154 additions & 0 deletions src/exchange/providers/google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os
from typing import Any, Dict, List, Tuple, Type

import httpx

from exchange import Message, Tool
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status

GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)


class GoogleProvider(Provider):
def __init__(self, client: httpx.Client) -> None:
self.client = client

@classmethod
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
try:
key = os.environ["GOOGLE_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key"
)

client = httpx.Client(
base_url=url,
headers={
"Content-Type": "application/json",
},
params={"key": key},
timeout=httpx.Timeout(60 * 10),
)
return cls(client)

@staticmethod
def get_usage(data: Dict) -> Usage: # noqa: ANN401
usage = data.get("usageMetadata")
input_tokens = usage.get("promptTokenCount")
output_tokens = usage.get("candidatesTokenCount")
total_tokens = usage.get("totalTokenCount")

if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens

return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)

@staticmethod
def google_response_to_message(response: Dict) -> Message:
candidates = response.get("candidates", [])
if candidates:
# Only use first candidate for now
candidate = candidates[0]
content_parts = candidate.get("content", {}).get("parts", [])
content = []
for part in content_parts:
if "text" in part:
content.append(Text(text=part["text"]))
elif "functionCall" in part:
content.append(
ToolUse(
id=part["functionCall"].get("name", ""),
name=part["functionCall"].get("name", ""),
parameters=part["functionCall"].get("args", {}),
)
)
return Message(role="assistant", content=content)

# If no valid candidates were found, return an empty message
return Message(role="assistant", content=[])

@staticmethod
def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]:
if not tools:
return {}
converted_tools = []
for tool in tools:
converted_tool: Dict[str, Any] = {
"name": tool.name,
"description": tool.description or "",
}
if tool.parameters["properties"]:
converted_tool["parameters"] = tool.parameters
converted_tools.append(converted_tool)
return {"functionDeclarations": converted_tools}

@staticmethod
def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]:
messages_spec = []
for message in messages:
role = "user" if message.role == "user" else "model"
converted = {"role": role, "parts": []}
for content in message.content:
if isinstance(content, Text):
converted["parts"].append({"text": content.text})
elif isinstance(content, ToolUse):
converted["parts"].append({"functionCall": {"name": content.name, "args": content.parameters}})
elif isinstance(content, ToolResult):
converted["parts"].append(
{"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}}
)
messages_spec.append(converted)

if not messages_spec:
messages_spec.append({"role": "user", "parts": [{"text": "Ignore"}]})

return messages_spec

def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: List[Tool] = [],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
tools_set = set()
unique_tools = []
for tool in tools:
if tool.name not in tools_set:
unique_tools.append(tool)
tools_set.add(tool.name)

payload = dict(
system_instruction={"parts": [{"text": system}]},
contents=self.messages_to_google_spec(messages),
tools=self.tools_to_google_spec(tuple(unique_tools)),
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
response = self._post(payload, model)
message = self.google_response_to_message(response)
usage = self.get_usage(response)
return message, usage

@retry_procedure
def _post(self, payload: dict, model: str) -> httpx.Response:
response = self.client.post("models/" + model + ":generateContent", json=payload)
return raise_for_status(response).json()
147 changes: 147 additions & 0 deletions tests/providers/test_google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os
from unittest.mock import patch

import httpx
import pytest
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.google import GoogleProvider
from exchange.tool import Tool


def example_fn(param: str) -> None:
"""
Testing function.

Args:
param (str): Description of param1
"""
pass


@pytest.fixture
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"})
def google_provider():
return GoogleProvider.from_env()


def test_google_response_to_text_message() -> None:
response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]}
message = GoogleProvider.google_response_to_message(response)
assert message.content[0].text == "Hello from Gemini!"


def test_google_response_to_tool_use_message() -> None:
response = {
"candidates": [
{
"content": {
"parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}],
"role": "model",
}
}
]
}

message = GoogleProvider.google_response_to_message(response)
assert message.content[0].name == "example_fn"
assert message.content[0].parameters == {"param": "value"}


def test_tools_to_google_spec() -> None:
tools = (Tool.from_function(example_fn),)
expected_spec = {
"functionDeclarations": [
{
"name": "example_fn",
"description": "Testing function.",
"parameters": {
"type": "object",
"properties": {"param": {"type": "string", "description": "Description of param1"}},
"required": ["param"],
},
}
]
}
result = GoogleProvider.tools_to_google_spec(tools)
assert result == expected_spec


def test_message_text_to_google_spec() -> None:
messages = [Message.user("Hello, Gemini")]
expected_spec = [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}]
result = GoogleProvider.messages_to_google_spec(messages)
assert result == expected_spec


def test_messages_to_google_spec() -> None:
messages = [
Message(role="user", content=[Text(text="Hello, Gemini")]),
Message(
role="assistant",
content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})],
),
Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]),
]
actual_spec = GoogleProvider.messages_to_google_spec(messages)
# !=
expected_spec = [
{"role": "user", "parts": [{"text": "Hello, Gemini"}]},
{"role": "model", "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}]},
{"role": "user", "parts": [{"functionResponse": {"name": "1", "response": {"content": "Result"}}}]},
]

assert actual_spec == expected_spec


@patch("httpx.Client.post")
@patch("logging.warning")
@patch("logging.error")
def test_google_completion(mock_error, mock_warning, mock_post, google_provider):
mock_response = {
"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}],
"usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13},
}

# First attempts fail with status code 429, 2nd succeeds
def create_response(status_code, json_data=None):
response = httpx.Response(status_code)
response._content = httpx._content.json_dumps(json_data or {}).encode()
response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/")
return response

mock_post.side_effect = [
create_response(429), # 1st attempt
create_response(200, mock_response), # Final success
]

model = "gemini-1.5-flash"
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]

reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages)

assert reply_message.content == [Text(text="Hello from Gemini!")]
assert reply_usage.total_tokens == 13
assert mock_post.call_count == 2
mock_post.assert_any_call(
"models/gemini-1.5-flash:generateContent",
json={
"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}],
},
)


@pytest.mark.integration
def test_google_integration():
provider = GoogleProvider.from_env()
model = "gemini-1.5-flash" # updated model to a known valid model
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]

# Run the completion
reply = provider.complete(model=model, system=system, messages=messages)

assert reply[0].content is not None
print("Completion content from Google:", reply[0].content)
1 change: 1 addition & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
(get_provider("azure"), os.getenv("AZURE_MODEL", "gpt-4o-mini"), dict()),
(get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()),
(get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()),
(get_provider("google"), "gemini-1.5-flash", dict()),
]


Expand Down