-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add provide experimental support for Google Gemini #67
- Loading branch information
1 parent
41f4e63
commit 55ab355
Showing
5 changed files
with
304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import os | ||
from typing import Any, Dict, List, Tuple, Type | ||
|
||
import httpx | ||
|
||
from exchange import Message, Tool | ||
from exchange.content import Text, ToolResult, ToolUse | ||
from exchange.providers.base import Provider, Usage | ||
from tenacity import retry, wait_fixed, stop_after_attempt | ||
from exchange.providers.utils import retry_if_status | ||
from exchange.providers.utils import raise_for_status | ||
|
||
GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" | ||
|
||
retry_procedure = retry( | ||
wait=wait_fixed(2), | ||
stop=stop_after_attempt(2), | ||
retry=retry_if_status(codes=[429], above=500), | ||
reraise=True, | ||
) | ||
|
||
|
||
class GoogleProvider(Provider): | ||
def __init__(self, client: httpx.Client) -> None: | ||
self.client = client | ||
|
||
@classmethod | ||
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": | ||
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) | ||
try: | ||
key = os.environ["GOOGLE_API_KEY"] | ||
except KeyError: | ||
raise RuntimeError( | ||
"Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key" | ||
) | ||
|
||
client = httpx.Client( | ||
base_url=url, | ||
headers={ | ||
"Content-Type": "application/json", | ||
}, | ||
params={"key": key}, | ||
timeout=httpx.Timeout(60 * 10), | ||
) | ||
return cls(client) | ||
|
||
@staticmethod | ||
def get_usage(data: Dict) -> Usage: # noqa: ANN401 | ||
usage = data.get("usageMetadata") | ||
input_tokens = usage.get("promptTokenCount") | ||
output_tokens = usage.get("candidatesTokenCount") | ||
total_tokens = usage.get("totalTokenCount") | ||
|
||
if total_tokens is None and input_tokens is not None and output_tokens is not None: | ||
total_tokens = input_tokens + output_tokens | ||
|
||
return Usage( | ||
input_tokens=input_tokens, | ||
output_tokens=output_tokens, | ||
total_tokens=total_tokens, | ||
) | ||
|
||
@staticmethod | ||
def google_response_to_message(response: Dict) -> Message: | ||
candidates = response.get("candidates", []) | ||
if candidates: | ||
# Only use first candidate for now | ||
candidate = candidates[0] | ||
content_parts = candidate.get("content", {}).get("parts", []) | ||
content = [] | ||
for part in content_parts: | ||
if "text" in part: | ||
content.append(Text(text=part["text"])) | ||
elif "functionCall" in part: | ||
content.append( | ||
ToolUse( | ||
id=part["functionCall"].get("name", ""), | ||
name=part["functionCall"].get("name", ""), | ||
parameters=part["functionCall"].get("args", {}), | ||
) | ||
) | ||
return Message(role="assistant", content=content) | ||
|
||
# If no valid candidates were found, return an empty message | ||
return Message(role="assistant", content=[]) | ||
|
||
@staticmethod | ||
def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]: | ||
if not tools: | ||
return {} | ||
converted_tools = [] | ||
for tool in tools: | ||
converted_tool: Dict[str, Any] = { | ||
"name": tool.name, | ||
"description": tool.description or "", | ||
} | ||
if tool.parameters["properties"]: | ||
converted_tool["parameters"] = tool.parameters | ||
converted_tools.append(converted_tool) | ||
return {"functionDeclarations": converted_tools} | ||
|
||
@staticmethod | ||
def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]: | ||
messages_spec = [] | ||
for message in messages: | ||
role = "user" if message.role == "user" else "model" | ||
converted = {"role": role, "parts": []} | ||
for content in message.content: | ||
if isinstance(content, Text): | ||
converted["parts"].append({"text": content.text}) | ||
elif isinstance(content, ToolUse): | ||
converted["parts"].append({"functionCall": {"name": content.name, "args": content.parameters}}) | ||
elif isinstance(content, ToolResult): | ||
converted["parts"].append( | ||
{"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}} | ||
) | ||
messages_spec.append(converted) | ||
|
||
if not messages_spec: | ||
messages_spec.append({"role": "user", "parts": [{"text": "Ignore"}]}) | ||
|
||
return messages_spec | ||
|
||
def complete( | ||
self, | ||
model: str, | ||
system: str, | ||
messages: List[Message], | ||
tools: List[Tool] = [], | ||
**kwargs: Dict[str, Any], | ||
) -> Tuple[Message, Usage]: | ||
tools_set = set() | ||
unique_tools = [] | ||
for tool in tools: | ||
if tool.name not in tools_set: | ||
unique_tools.append(tool) | ||
tools_set.add(tool.name) | ||
|
||
payload = dict( | ||
system_instruction={"parts": [{"text": system}]}, | ||
contents=self.messages_to_google_spec(messages), | ||
tools=self.tools_to_google_spec(tuple(unique_tools)), | ||
**kwargs, | ||
) | ||
payload = {k: v for k, v in payload.items() if v} | ||
response = self._post(payload, model) | ||
message = self.google_response_to_message(response) | ||
usage = self.get_usage(response) | ||
return message, usage | ||
|
||
@retry_procedure | ||
def _post(self, payload: dict, model: str) -> httpx.Response: | ||
response = self.client.post("models/" + model + ":generateContent", json=payload) | ||
return raise_for_status(response).json() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import os | ||
from unittest.mock import patch | ||
|
||
import httpx | ||
import pytest | ||
from exchange import Message, Text | ||
from exchange.content import ToolResult, ToolUse | ||
from exchange.providers.google import GoogleProvider | ||
from exchange.tool import Tool | ||
|
||
|
||
def example_fn(param: str) -> None: | ||
""" | ||
Testing function. | ||
Args: | ||
param (str): Description of param1 | ||
""" | ||
pass | ||
|
||
|
||
@pytest.fixture | ||
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) | ||
def google_provider(): | ||
return GoogleProvider.from_env() | ||
|
||
|
||
def test_google_response_to_text_message() -> None: | ||
response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]} | ||
message = GoogleProvider.google_response_to_message(response) | ||
assert message.content[0].text == "Hello from Gemini!" | ||
|
||
|
||
def test_google_response_to_tool_use_message() -> None: | ||
response = { | ||
"candidates": [ | ||
{ | ||
"content": { | ||
"parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}], | ||
"role": "model", | ||
} | ||
} | ||
] | ||
} | ||
|
||
message = GoogleProvider.google_response_to_message(response) | ||
assert message.content[0].name == "example_fn" | ||
assert message.content[0].parameters == {"param": "value"} | ||
|
||
|
||
def test_tools_to_google_spec() -> None: | ||
tools = (Tool.from_function(example_fn),) | ||
expected_spec = { | ||
"functionDeclarations": [ | ||
{ | ||
"name": "example_fn", | ||
"description": "Testing function.", | ||
"parameters": { | ||
"type": "object", | ||
"properties": {"param": {"type": "string", "description": "Description of param1"}}, | ||
"required": ["param"], | ||
}, | ||
} | ||
] | ||
} | ||
result = GoogleProvider.tools_to_google_spec(tools) | ||
assert result == expected_spec | ||
|
||
|
||
def test_message_text_to_google_spec() -> None: | ||
messages = [Message.user("Hello, Gemini")] | ||
expected_spec = [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}] | ||
result = GoogleProvider.messages_to_google_spec(messages) | ||
assert result == expected_spec | ||
|
||
|
||
def test_messages_to_google_spec() -> None: | ||
messages = [ | ||
Message(role="user", content=[Text(text="Hello, Gemini")]), | ||
Message( | ||
role="assistant", | ||
content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], | ||
), | ||
Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]), | ||
] | ||
actual_spec = GoogleProvider.messages_to_google_spec(messages) | ||
# != | ||
expected_spec = [ | ||
{"role": "user", "parts": [{"text": "Hello, Gemini"}]}, | ||
{"role": "model", "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}]}, | ||
{"role": "user", "parts": [{"functionResponse": {"name": "1", "response": {"content": "Result"}}}]}, | ||
] | ||
|
||
assert actual_spec == expected_spec | ||
|
||
|
||
@patch("httpx.Client.post") | ||
@patch("logging.warning") | ||
@patch("logging.error") | ||
def test_google_completion(mock_error, mock_warning, mock_post, google_provider): | ||
mock_response = { | ||
"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}], | ||
"usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13}, | ||
} | ||
|
||
# First attempts fail with status code 429, 2nd succeeds | ||
def create_response(status_code, json_data=None): | ||
response = httpx.Response(status_code) | ||
response._content = httpx._content.json_dumps(json_data or {}).encode() | ||
response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/") | ||
return response | ||
|
||
mock_post.side_effect = [ | ||
create_response(429), # 1st attempt | ||
create_response(200, mock_response), # Final success | ||
] | ||
|
||
model = "gemini-1.5-flash" | ||
system = "You are a helpful assistant." | ||
messages = [Message.user("Hello, Gemini")] | ||
|
||
reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages) | ||
|
||
assert reply_message.content == [Text(text="Hello from Gemini!")] | ||
assert reply_usage.total_tokens == 13 | ||
assert mock_post.call_count == 2 | ||
mock_post.assert_any_call( | ||
"models/gemini-1.5-flash:generateContent", | ||
json={ | ||
"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]}, | ||
"contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}], | ||
}, | ||
) | ||
|
||
|
||
@pytest.mark.integration | ||
def test_google_integration(): | ||
provider = GoogleProvider.from_env() | ||
model = "gemini-1.5-flash" # updated model to a known valid model | ||
system = "You are a helpful assistant." | ||
messages = [Message.user("Hello, Gemini")] | ||
|
||
# Run the completion | ||
reply = provider.complete(model=model, system=system, messages=messages) | ||
|
||
assert reply[0].content is not None | ||
print("Completion content from Google:", reply[0].content) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters