test: reduce code redundancy in openai based tests #54

Merged: 4 commits, Sep 26, 2024. Changes shown are from 1 commit.
9 changes: 6 additions & 3 deletions src/exchange/providers/ollama.py
@@ -37,7 +37,10 @@ def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
base_url=url,
timeout=httpx.Timeout(60 * 10),
)
# from_env is expected to fail if provider is not available
# so we run a quick test that the endpoint is running
client.get("")
# from_env is expected to fail if required ENV variables are not
# available. Since this provider can run with defaults, we substitute
# a health check to verify the endpoint is running.
client.get("/")
# The OpenAI API is defined after "v1/", so we need to join it here.
client.base_url = client.base_url.join("v1/")
Collaborator: nit: I'd put this above on line 37.

Collaborator: I'll do a quick fix in the 0.9.3 release so we can get this in.

Contributor (author): Will do on next change, thanks!

Contributor (author): Oh, it is subtle: the health check hits Ollama itself, so we need to test the base of Ollama before appending "v1". It would be better if we didn't need a health check imho, especially as we don't do this in any other provider. Let me know your thoughts on that.

Contributor (author): Found a way that may work fine in #60.

return cls(client)
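
For context, the resulting OllamaProvider.from_env reads roughly as below. This is a sketch reconstructed from the hunks in this file; the OLLAMA_HOST variable name and its localhost default are assumptions, since the diff doesn't show them.

import os
from typing import Type

import httpx


class OllamaProvider:
    def __init__(self, client: httpx.Client) -> None:
        self.client = client

    @classmethod
    def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
        url = os.environ.get("OLLAMA_HOST", "http://localhost:11434/")  # assumed
        client = httpx.Client(
            base_url=url,
            timeout=httpx.Timeout(60 * 10),
        )
        # A plain GET against the Ollama root; raises httpx.ConnectError
        # if the server isn't running, making from_env fail fast.
        client.get("/")
        # Only after the health check do we mount the OpenAI-compatible
        # API under "v1/", since the root path is not served there.
        client.base_url = client.base_url.join("v1/")
        return cls(client)

This also illustrates the ordering point raised in the thread above: the health check must hit the bare Ollama root before "v1/" is joined onto base_url.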
8 changes: 6 additions & 2 deletions src/exchange/providers/openai.py
@@ -43,7 +43,7 @@ def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
"Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys"
)
client = httpx.Client(
base_url=url,
base_url=url + "v1/",
auth=("Bearer", key),
timeout=httpx.Timeout(60 * 10),
)
@@ -93,5 +93,9 @@ def complete(

@retry_procedure
def _post(self, payload: dict) -> dict:
response = self.client.post("v1/chat/completions", json=payload)
# Note: While OpenAI and Ollama mount the API under "v1", this is
# conventional and not a strict requirement. For example, Azure OpenAI
# mounts the API under the deployment name, and "v1" is not in the URL.
# See https://github.com/openai/openai-openapi/blob/master/openapi.yaml
response = self.client.post("chat/completions", json=payload)
return raise_for_status(response).json()
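
The trailing-slash and relative-path mechanics that both hunks rely on are easy to get wrong, so here is a small standalone sketch of httpx's URL handling, independent of this repo:

import httpx

# RFC 3986 join: "v1/" is appended as a new path segment because the
# base URL ends with "/".
base = httpx.URL("http://localhost:11434/")
print(base.join("v1/"))  # http://localhost:11434/v1/

# With base_url mounted at .../v1/, a relative request path lands under it,
# which is why the post above uses "chat/completions" with no "v1/" prefix.
client = httpx.Client(base_url="https://api.openai.com/v1/")
request = client.build_request("POST", "chat/completions", json={})
print(request.url)  # https://api.openai.com/v1/chat/completions

Note that the url + "v1/" string concatenation in the first hunk assumes url ends with a trailing slash; URL.join, as used in the Ollama provider, handles either form.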
24 changes: 17 additions & 7 deletions tests/providers/openai/conftest.py
@@ -1,22 +1,25 @@
import os
from typing import Type, Tuple

import pytest

OPENAI_MODEL = "gpt-4o-mini"
from exchange import Message
from exchange.providers import Usage, Provider

OPENAI_API_KEY = "test_openai_api_key"
OPENAI_ORG_ID = "test_openai_org_key"
OPENAI_PROJECT_ID = "test_openai_project_id"


@pytest.fixture
def default_openai_api_key(monkeypatch):
def default_openai_env(monkeypatch):
"""
This fixture avoids the error OpenAiProvider.from_env() raises when the
OPENAI_API_KEY is not set in the environment.
This fixture prevents OpenAiProvider.from_env() from erroring when required
environment variables are missing.

When running VCR tests for the first time or after deleting a cassette
recording, a real OPENAI_API_KEY must be passed as an environment variable,
so real responses can be fetched. Subsequent runs use the recorded data, so
don't need a real key.
recording, set the required environment variables so that real requests don't
fail. Subsequent runs use the recorded data, so they don't need them.
"""
if "OPENAI_API_KEY" not in os.environ:
monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY)
Expand Down Expand Up @@ -50,3 +53,10 @@ def scrub_response_headers(response):
response["headers"]["openai-organization"] = OPENAI_ORG_ID
response["headers"]["Set-Cookie"] = "test_set_cookie"
return response


def complete(provider_cls: Type[Provider], model: str) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None)
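
The collapsed region above hides how scrub_response_headers is wired in. If the suite uses pytest-recording (implied by the cassette workflow described in the fixture docstring), the hookup would look something like this hypothetical vcr_config fixture; the exact options are assumptions, not shown in the diff:

import pytest


@pytest.fixture(scope="module")
def vcr_config():
    # Hypothetical wiring, assuming pytest-recording/vcrpy: redact secrets
    # before cassettes are written so recordings are safe to commit.
    return {
        "filter_headers": ["authorization", "openai-organization", "openai-project"],
        "before_record_response": scrub_response_headers,
    }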
24 changes: 8 additions & 16 deletions tests/providers/openai/test_ollama.py
@@ -1,33 +1,25 @@
from typing import Tuple

import os

import pytest

from exchange import Text
from exchange.message import Message
from exchange.providers.base import Usage
from exchange.providers.ollama import OllamaProvider, OLLAMA_MODEL
from .conftest import complete

OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL)


@pytest.mark.vcr()
def test_ollama_completion(default_openai_api_key):
reply_message, reply_usage = ollama_complete()
def test_ollama_complete():
reply_message, reply_usage = complete(OllamaProvider, OLLAMA_MODEL)

assert reply_message.content == [Text(text="Hello! I'm here to help. How can I assist you today? Let's chat. 😊")]
assert reply_usage.total_tokens == 33


@pytest.mark.integration
def test_ollama_completion_integration():
reply = ollama_complete()
def test_ollama_complete_integration():
reply = complete(OllamaProvider, OLLAMA_MODEL)

assert reply[0].content is not None
print("Completion content from OpenAI:", reply[0].content)


def ollama_complete() -> Tuple[Message, Usage]:
provider = OllamaProvider.from_env()
model = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL)
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None)
32 changes: 9 additions & 23 deletions tests/providers/openai/test_openai.py
@@ -1,39 +1,25 @@
from typing import Tuple

import os

import pytest

from exchange import Text
from exchange.message import Message
from exchange.providers.base import Usage
from exchange.providers.openai import OpenAiProvider
from .conftest import OPENAI_MODEL, OPENAI_API_KEY
from .conftest import complete

OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

@pytest.mark.vcr()
def test_openai_completion(monkeypatch):
# When running VCR tests the first time, it needs OPENAI_API_KEY to call
# the real service. Afterward, it is not needed as VCR mocks the service.
if "OPENAI_API_KEY" not in os.environ:
monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY)

reply_message, reply_usage = openai_complete()
@pytest.mark.vcr()
def test_openai_complete(default_openai_env):
reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL)

assert reply_message.content == [Text(text="Hello! How can I assist you today?")]
assert reply_usage.total_tokens == 27


@pytest.mark.integration
def test_openai_completion_integration():
reply = openai_complete()
def test_openai_complete_integration():
reply = complete(OpenAiProvider, OPENAI_MODEL)

assert reply[0].content is not None
print("Completion content from OpenAI:", reply[0].content)


def openai_complete() -> Tuple[Message, Usage]:
provider = OpenAiProvider.from_env()
model = OPENAI_MODEL
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None)
print("Complete content from OpenAI:", reply[0].content)
6 changes: 3 additions & 3 deletions tests/test_integration.py
@@ -12,13 +12,13 @@
cases = [
# Set seed and temperature for more determinism, to avoid flakes
(get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL), dict(seed=3, temperature=0.1)),
(get_provider("openai"), "gpt-4o-mini", dict()),
(get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini"), dict()),
(get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()),
(get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()),
]


@pytest.mark.integration # skipped in CI/CD
@pytest.mark.integration
@pytest.mark.parametrize("provider,model,kwargs", cases)
def test_simple(provider, model, kwargs):
provider = provider.from_env()
@@ -39,7 +39,7 @@ def test_simple(provider, model, kwargs):
assert "gandalf" in response.text.lower()


@pytest.mark.integration # skipped in CI/CD
@pytest.mark.integration
@pytest.mark.parametrize("provider,model,kwargs", cases)
def test_tools(provider, model, kwargs, tmp_path):
provider = provider.from_env()
4 changes: 3 additions & 1 deletion tests/test_integration_vision.py
@@ -1,3 +1,5 @@
import os

import pytest
from exchange.content import ToolResult, ToolUse
from exchange.exchange import Exchange
@@ -6,7 +8,7 @@
from exchange.providers import get_provider

cases = [
(get_provider("openai"), "gpt-4o-mini"),
(get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini")),
]

