From 329274e47bb1dc311c5735be3516406c30696c65 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 2 Oct 2024 12:24:09 +1000 Subject: [PATCH] refactor instructions to instructions_url --- src/exchange/providers/base.py | 10 +++++----- src/exchange/providers/google.py | 4 ++-- src/exchange/providers/openai.py | 4 ++-- src/exchange/providers/utils.py | 4 ++-- tests/providers/test_anthropic.py | 2 +- tests/providers/test_azure.py | 2 +- tests/providers/test_bedrock.py | 2 +- tests/test_base.py | 27 +++++++++++++++++++++++++++ 8 files changed, 41 insertions(+), 14 deletions(-) create mode 100644 tests/test_base.py diff --git a/src/exchange/providers/base.py b/src/exchange/providers/base.py index 38bcf74..7ec8745 100644 --- a/src/exchange/providers/base.py +++ b/src/exchange/providers/base.py @@ -31,11 +31,11 @@ def complete( class MissingProviderEnvVariableError(Exception): - def __init__(self, env_variable: str, provider: str, instructions: Optional[str] = None) -> None: + def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None: self.env_variable = env_variable self.provider = provider - self.instructions = instructions - self.message = f"Missing environment variable: {env_variable} for provider {provider}" - if instructions: - self.message += f". {instructions}" + self.instructions_url = instructions_url + self.message = f"Missing environment variable: {env_variable} for provider {provider}." + if instructions_url: + self.message += f"\n Please see {instructions_url} for instructions" super().__init__(self.message) diff --git a/src/exchange/providers/google.py b/src/exchange/providers/google.py index b58ddb6..349b803 100644 --- a/src/exchange/providers/google.py +++ b/src/exchange/providers/google.py @@ -27,8 +27,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) - api_key_instructions = "see https://ai.google.dev/gemini-api/docs/api-key" - key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions) + api_key_instructions_url = "see https://ai.google.dev/gemini-api/docs/api-key" + key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url) client = httpx.Client( base_url=url, headers={ diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index 1364c04..d921020 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -37,8 +37,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": url = os.environ.get("OPENAI_HOST", OPENAI_HOST) - api_key_instructions = "see https://platform.openai.com/docs/api-reference/api-keys" - key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions) + api_key_instructions_url = "see https://platform.openai.com/docs/api-reference/api-keys" + key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url) client = httpx.Client( base_url=url + "v1/", auth=("Bearer", key), diff --git a/src/exchange/providers/utils.py b/src/exchange/providers/utils.py index 7150598..0150430 100644 --- a/src/exchange/providers/utils.py +++ b/src/exchange/providers/utils.py @@ -181,11 +181,11 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None: raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") -def get_provider_env_value(env_variable: str, provider: str, instructions: Optional[str] = None) -> str: +def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str: try: return os.environ[env_variable] except KeyError: - raise MissingProviderEnvVariableError(env_variable, provider, instructions) + raise MissingProviderEnvVariableError(env_variable, provider, instructions_url) class InitialMessageTooLargeError(Exception): diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index 5db3ec5..272ebcb 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -32,7 +32,7 @@ def test_from_env_throw_error_when_missing_api_key(): AnthropicProvider.from_env() assert context.value.provider == "anthropic" assert context.value.env_variable == "ANTHROPIC_API_KEY" - assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic" + assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic." def test_anthropic_response_to_text_message() -> None: diff --git a/tests/providers/test_azure.py b/tests/providers/test_azure.py index 9d4203f..b46be30 100644 --- a/tests/providers/test_azure.py +++ b/tests/providers/test_azure.py @@ -36,7 +36,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): AzureProvider.from_env() assert context.value.provider == "azure" assert context.value.env_variable == env_var_name - assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure" + assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure." @pytest.mark.vcr() diff --git a/tests/providers/test_bedrock.py b/tests/providers/test_bedrock.py index d31a738..f8fcaa4 100644 --- a/tests/providers/test_bedrock.py +++ b/tests/providers/test_bedrock.py @@ -35,7 +35,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): BedrockProvider.from_env() assert context.value.provider == "bedrock" assert context.value.env_variable == env_var_name - assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock" + assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock." @pytest.fixture diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..9b3d6d4 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,27 @@ +from exchange.providers.base import MissingProviderEnvVariableError + + +def test_missing_provider_env_variable_error_without_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + error = MissingProviderEnvVariableError(env_variable, provider) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url is None + assert error.message == "Missing environment variable: API_KEY for provider TestProvider." + + +def test_missing_provider_env_variable_error_with_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + instructions_url = "http://example.com/instructions" + error = MissingProviderEnvVariableError(env_variable, provider, instructions_url) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url == instructions_url + assert error.message == ( + "Missing environment variable: API_KEY for provider TestProvider.\n" + " Please see http://example.com/instructions for instructions" + )