Skip to content

Commit

Permalink
created custom error when env var for providers does not exist
Browse files Browse the repository at this point in the history
  • Loading branch information
lifeizhou-ap committed Sep 30, 2024
1 parent 41f4e63 commit 4b207bf
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 51 deletions.
7 changes: 2 additions & 5 deletions src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import get_provider_env_value, retry_if_status
from exchange.providers.utils import raise_for_status

ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
Expand All @@ -27,10 +27,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
try:
key = os.environ["ANTHROPIC_API_KEY"]
except KeyError:
raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment")
key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic")
client = httpx.Client(
base_url=url,
headers={
Expand Down
31 changes: 12 additions & 19 deletions src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,23 @@
import httpx

from exchange.providers import OpenAiProvider
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.utils import get_provider_env_value


class AzureProvider(OpenAiProvider):
"""Provides chat completions for models hosted by the Azure OpenAI Service"""

def __init__(self, client: httpx.Client) -> None:
super().__init__(client)

@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
try:
url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.")

try:
deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.")
url = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_HOST_NAME")
deployment_name = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")

try:
api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.")

try:
key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")
api_version = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION")
key = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_KEY")

# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}/"
Expand All @@ -43,3 +31,8 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
timeout=httpx.Timeout(60 * 10),
)
return cls(client)

@classmethod
def _get_env_variable(cls:Type["AzureProvider"], key: str) -> str:
return get_provider_env_value(key, "azure")

12 changes: 11 additions & 1 deletion src/exchange/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from attrs import define, field
from typing import List, Tuple, Type
from typing import List, Optional, Tuple, Type

from exchange.message import Message
from exchange.tool import Tool
Expand Down Expand Up @@ -28,3 +28,13 @@ def complete(
) -> Tuple[Message, Usage]:
"""Generate the next message using the specified model"""
pass

class MissingProviderEnvVariableError(Exception):
def __init__(self, env_variable: str, provider: str, instructions: Optional[str] = None) -> None:
self.env_variable = env_variable
self.provider = provider
self.instructions = instructions
self.message = f"Missing environment variable: {env_variable} for provider {provider}"
if instructions:
self.message += f". {instructions}"
super().__init__(self.message)
15 changes: 8 additions & 7 deletions src/exchange/providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from exchange.message import Message
from exchange.providers import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import get_provider_env_value, retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.tool import Tool

Expand Down Expand Up @@ -154,12 +154,9 @@ def __init__(self, client: AwsClient) -> None:
@classmethod
def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
aws_region = os.environ.get("AWS_REGION", "us-east-1")
try:
aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_session_token = os.environ.get("AWS_SESSION_TOKEN")
except KeyError:
raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment")
aws_access_key = cls._get_env_variable("AWS_ACCESS_KEY_ID")
aws_secret_key = cls._get_env_variable("AWS_SECRET_ACCESS_KEY")
aws_session_token = cls._get_env_variable("AWS_SESSION_TOKEN")

client = AwsClient(
aws_region=aws_region,
Expand Down Expand Up @@ -326,3 +323,7 @@ def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]:
tools_added.add(tool.name)
tool_config = {"tools": tool_config_list}
return tool_config

@classmethod
def _get_env_variable(cls:Type["BedrockProvider"], key: str) -> str:
return get_provider_env_value(key, "bedrock")
21 changes: 8 additions & 13 deletions src/exchange/providers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
Expand Down Expand Up @@ -37,18 +37,8 @@ def __init__(self, client: httpx.Client) -> None:

@classmethod
def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
try:
url = os.environ["DATABRICKS_HOST"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
try:
key = os.environ["DATABRICKS_TOKEN"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
url = cls._get_env_variable("DATABRICKS_HOST")
key = cls._get_env_variable("DATABRICKS_TOKEN")
client = httpx.Client(
base_url=url,
auth=("token", key),
Expand Down Expand Up @@ -100,3 +90,8 @@ def _post(self, model: str, payload: dict) -> httpx.Response:
json=payload,
)
return raise_for_status(response).json()

@classmethod
def _get_env_variable(cls:Type["DatabricksProvider"], key: str) -> str:
instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
return get_provider_env_value(key, "databricks", instruction)
9 changes: 3 additions & 6 deletions src/exchange/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import (
get_provider_env_value,
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
Expand Down Expand Up @@ -36,12 +37,8 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
try:
key = os.environ["OPENAI_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys"
)
api_key_instructions = "see https://platform.openai.com/docs/api-reference/api-keys"
key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions)
client = httpx.Client(
base_url=url + "v1/",
auth=("Bearer", key),
Expand Down
7 changes: 7 additions & 0 deletions src/exchange/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import base64
import json
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

import httpx
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool
from tenacity import retry_if_exception

Expand Down Expand Up @@ -178,6 +180,11 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None:
if code == "context_length_exceeded" or code == "string_above_max_length":
raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}")

def get_provider_env_value(env_variable: str, provider: str, instructions: Optional[str] = None) -> str:
try:
return os.environ[env_variable]
except KeyError:
raise MissingProviderEnvVariableError(env_variable, provider, instructions)

class InitialMessageTooLargeError(Exception):
"""Custom error raised when the first input message in an exchange is too large."""
Expand Down
8 changes: 8 additions & 0 deletions tests/providers/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.anthropic import AnthropicProvider
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool


Expand All @@ -24,6 +25,13 @@ def example_fn(param: str) -> None:
def anthropic_provider():
return AnthropicProvider.from_env()

def test_from_env_throw_error_when_missing_api_key():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(MissingProviderEnvVariableError) as context:
AnthropicProvider.from_env()
assert context.value.provider == "anthropic"
assert context.value.env_variable == "ANTHROPIC_API_KEY"
assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic"

def test_anthropic_response_to_text_message() -> None:
response = {
Expand Down
24 changes: 24 additions & 0 deletions tests/providers/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,37 @@
import os
from unittest.mock import patch

import pytest

from exchange import Text, ToolUse
from exchange.providers.azure import AzureProvider
from exchange.providers.base import MissingProviderEnvVariableError
from .conftest import complete, tools

AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")

@pytest.mark.parametrize(
"env_var_name",
[
("AZURE_CHAT_COMPLETIONS_HOST_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"),
("AZURE_CHAT_COMPLETIONS_KEY"),
]
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(os.environ, {
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
}, clear=True):
os.environ.pop(env_var_name)
with pytest.raises(MissingProviderEnvVariableError) as context:
AzureProvider.from_env()
assert context.value.provider == "azure"
assert context.value.env_variable == env_var_name
assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure"

@pytest.mark.vcr()
def test_azure_complete(default_azure_env):
Expand Down
21 changes: 21 additions & 0 deletions tests/providers/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,32 @@
import pytest
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.bedrock import BedrockProvider
from exchange.tool import Tool

logger = logging.getLogger(__name__)

@pytest.mark.parametrize(
"env_var_name",
[
("AWS_ACCESS_KEY_ID"),
("AWS_SECRET_ACCESS_KEY"),
("AWS_SESSION_TOKEN"),
]
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(os.environ, {
"AWS_ACCESS_KEY_ID": "test_access_key_id",
"AWS_SECRET_ACCESS_KEY": "test_secret_access_key",
"AWS_SESSION_TOKEN": "test_session_token",
}, clear=True):
os.environ.pop(env_var_name)
with pytest.raises(MissingProviderEnvVariableError) as context:
BedrockProvider.from_env()
assert context.value.provider == "bedrock"
assert context.value.env_variable == env_var_name
assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock"

@pytest.fixture
@patch.dict(
Expand Down
20 changes: 20 additions & 0 deletions tests/providers/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,28 @@

import pytest
from exchange import Message, Text
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.databricks import DatabricksProvider

@pytest.mark.parametrize(
"env_var_name",
[
("DATABRICKS_HOST"),
("DATABRICKS_TOKEN"),
]
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(os.environ, {
"DATABRICKS_HOST": "test_host",
"DATABRICKS_TOKEN": "test_token",
}, clear=True):
os.environ.pop(env_var_name)
with pytest.raises(MissingProviderEnvVariableError) as context:
DatabricksProvider.from_env()
assert context.value.provider == "databricks"
assert context.value.env_variable == env_var_name
assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message
assert "https://docs.databricks.com" in context.value.message

@pytest.fixture
@patch.dict(
Expand Down
11 changes: 11 additions & 0 deletions tests/providers/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
import os
from unittest.mock import patch

import pytest

from exchange import Text, ToolUse
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.openai import OpenAiProvider
from .conftest import complete, vision, tools

OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")


def test_from_env_throw_error_when_missing_api_key():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(MissingProviderEnvVariableError) as context:
OpenAiProvider.from_env()
assert context.value.provider == "openai"
assert context.value.env_variable == "OPENAI_API_KEY"
assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message
assert "https://platform.openai.com" in context.value.message

@pytest.mark.vcr()
def test_openai_complete(default_openai_env):
reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL)
Expand Down

0 comments on commit 4b207bf

Please sign in to comment.