Skip to content

Commit

Permalink
test: convert Google Gemini tests to VCR
Browse files Browse the repository at this point in the history
part of square/exchange#67
closes square/exchange#71

Signed-off-by: Adrian Cole <[email protected]>
  • Loading branch information
codefromthecrypt committed Oct 4, 2024
1 parent 499f37f commit 197f0dd
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
interactions:
- request:
body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '139'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key
response:
body:
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
[\n {\n \"text\": \"Hello! \U0001F44B How can I help
you today? \U0001F60A \\n\"\n }\n ],\n \"role\": \"model\"\n
\ },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\":
[\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n
\ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n
\ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
8,\n \"candidatesTokenCount\": 12,\n \"totalTokenCount\": 20\n }\n}\n"
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Cache-Control:
- private
Content-Type:
- application/json; charset=UTF-8
Date:
- Wed, 02 Oct 2024 01:06:50 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=426
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
content-length:
- '855'
status:
code: 200
message: OK
version: 1
73 changes: 73 additions & 0 deletions packages/exchange/tests/providers/cassettes/test_google_tools.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
interactions:
- request:
body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant.
Expect to need to read a file using read_file."}]}, "contents": [{"role": "user",
"parts": [{"text": "What are the contents of this file? test.txt"}]}], "tools":
{"functionDeclarations": [{"name": "read_file", "description": "Read the contents
of the file.", "parameters": {"type": "object", "properties": {"filename": {"type":
"string", "description": "The path to the file, which can be relative or\nabsolute.
If it is a plain filename, it is assumed to be in the\ncurrent working directory."}},
"required": ["filename"]}}]}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '600'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key
response:
body:
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
[\n {\n \"functionCall\": {\n \"name\": \"read_file\",\n
\ \"args\": {\n \"filename\": \"test.txt\"\n }\n
\ }\n }\n ],\n \"role\": \"model\"\n },\n
\ \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\":
[\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n
\ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n
\ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
101,\n \"candidatesTokenCount\": 17,\n \"totalTokenCount\": 118\n }\n}\n"
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Cache-Control:
- private
Content-Type:
- application/json; charset=UTF-8
Date:
- Wed, 02 Oct 2024 01:06:51 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=449
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
content-length:
- '947'
status:
code: 200
message: OK
version: 1
29 changes: 25 additions & 4 deletions packages/exchange/tests/providers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ def default_azure_env(monkeypatch):
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY)


GOOGLE_API_KEY = "test_google_api_key"


@pytest.fixture
def default_google_env(monkeypatch):
"""
This fixture prevents GoogleProvider.from_env() from erring on missing
environment variables.
When running VCR tests for the first time or after deleting a cassette
recording, set required environment variables, so that real requests don't
fail. Subsequent runs use the recorded data, so don't need them.
"""
if "GOOGLE_API_KEY" not in os.environ:
monkeypatch.setenv("GOOGLE_API_KEY", GOOGLE_API_KEY)


@pytest.fixture(scope="module")
def vcr_config():
"""
Expand Down Expand Up @@ -85,6 +102,8 @@ def scrub_request_url(request):
request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri)
request.headers["host"] = AZURE_ENDPOINT.replace("https://", "")
request.headers["api-key"] = AZURE_API_KEY
elif "generativelanguage.googleapis.com" in request.uri:
request.uri = re.sub(r"([?&])key=[^&]+", r"\1key=" + GOOGLE_API_KEY, request.uri)

return request

Expand All @@ -93,16 +112,18 @@ def scrub_response_headers(response):
"""
This scrubs sensitive response headers. Note they are case-sensitive!
"""
response["headers"]["openai-organization"] = OPENAI_ORG_ID
response["headers"]["Set-Cookie"] = "test_set_cookie"
if "openai-organization" in response["headers"]:
response["headers"]["openai-organization"] = OPENAI_ORG_ID
if "Set-Cookie" in response["headers"]:
response["headers"]["Set-Cookie"] = "test_set_cookie"
return response


def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)


def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
Expand All @@ -128,4 +149,4 @@ def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message,
content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')],
),
]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)
80 changes: 31 additions & 49 deletions packages/exchange/tests/providers/test_google.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
from unittest.mock import patch

import httpx
import pytest
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.google import GoogleProvider
from exchange.tool import Tool
from .conftest import complete, tools

GOOGLE_MODEL = os.getenv("GOOGLE_MODEL", "gemini-1.5-flash")


def example_fn(param: str) -> None:
Expand All @@ -30,12 +32,6 @@ def test_from_env_throw_error_when_missing_api_key():
assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message


@pytest.fixture
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"})
def google_provider():
return GoogleProvider.from_env()


def test_google_response_to_text_message() -> None:
response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]}
message = GoogleProvider.google_response_to_message(response)
Expand Down Expand Up @@ -105,54 +101,40 @@ def test_messages_to_google_spec() -> None:
assert actual_spec == expected_spec


@patch("httpx.Client.post")
@patch("logging.warning")
@patch("logging.error")
def test_google_completion(mock_error, mock_warning, mock_post, google_provider):
mock_response = {
"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}],
"usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13},
}
@pytest.mark.vcr()
def test_google_complete(default_google_env):
reply_message, reply_usage = complete(GoogleProvider, GOOGLE_MODEL)

# First attempts fail with status code 429, 2nd succeeds
def create_response(status_code, json_data=None):
response = httpx.Response(status_code)
response._content = httpx._content.json_dumps(json_data or {}).encode()
response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/")
return response
assert reply_message.content == [Text("Hello! 👋 How can I help you today? 😊 \n")]
assert reply_usage.total_tokens == 20

mock_post.side_effect = [
create_response(429), # 1st attempt
create_response(200, mock_response), # Final success
]

model = "gemini-1.5-flash"
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]
@pytest.mark.integration
def test_google_complete_integration():
reply = complete(GoogleProvider, GOOGLE_MODEL)

reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages)
assert reply[0].content is not None
print("Completion content from Google:", reply[0].content)

assert reply_message.content == [Text(text="Hello from Gemini!")]
assert reply_usage.total_tokens == 13
assert mock_post.call_count == 2
mock_post.assert_any_call(
"models/gemini-1.5-flash:generateContent",
json={
"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}],
},
)

@pytest.mark.vcr()
def test_google_tools(default_google_env):
reply_message, reply_usage = tools(GoogleProvider, GOOGLE_MODEL)

@pytest.mark.integration
def test_google_integration():
provider = GoogleProvider.from_env()
model = "gemini-1.5-flash" # updated model to a known valid model
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]
tool_use = reply_message.content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id == "read_file"
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}
assert reply_usage.total_tokens == 118

# Run the completion
reply = provider.complete(model=model, system=system, messages=messages)

assert reply[0].content is not None
print("Completion content from Google:", reply[0].content)
@pytest.mark.integration
def test_google_tools_integration():
reply = tools(GoogleProvider, GOOGLE_MODEL)

tool_use = reply[0].content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id is not None
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}

0 comments on commit 197f0dd

Please sign in to comment.