From dd5595e69723f167211f196a07cc941fbef955b1 Mon Sep 17 00:00:00 2001 From: Bradley Axen Date: Thu, 12 Sep 2024 12:16:44 -0700 Subject: [PATCH] feat: Rework error handling The previous implementation was too inclusive in many cases, and also squashed the original error output which made it hard to identify problems in setup. This approach should let us support retrying flakey errors while not squashing real issues like access, incorrect model ids, etc Adding tenacity as a dep - it has an apache license, zero dependencies, and provides a great standard for error handling --- pyproject.toml | 1 + src/exchange/providers/anthropic.py | 25 ++- src/exchange/providers/azure.py | 29 +-- src/exchange/providers/bedrock.py | 26 ++- src/exchange/providers/databricks.py | 26 ++- src/exchange/providers/openai.py | 35 ++-- .../retry_with_back_off_decorator.py | 61 ------ src/exchange/providers/utils.py | 17 +- tests/providers/test_anthropic.py | 14 +- .../test_retry_with_back_off_decorator.py | 174 ------------------ 10 files changed, 110 insertions(+), 298 deletions(-) delete mode 100644 src/exchange/providers/retry_with_back_off_decorator.py delete mode 100644 tests/providers/test_retry_with_back_off_decorator.py diff --git a/pyproject.toml b/pyproject.toml index f4c1a85..797e53d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "jinja2>=3.1.4", "tiktoken>=0.7.0", "httpx>=0.27.0", + "tenacity>=9.0.0", ] [tool.hatch.build.targets.wheel] diff --git a/src/exchange/providers/anthropic.py b/src/exchange/providers/anthropic.py index 7319ee8..154ec5f 100644 --- a/src/exchange/providers/anthropic.py +++ b/src/exchange/providers/anthropic.py @@ -6,11 +6,19 @@ from exchange import Message, Tool from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status from exchange.providers.utils import raise_for_status ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + class AnthropicProvider(Provider): def __init__(self, client: httpx.Client) -> None: @@ -138,14 +146,13 @@ def complete( ) payload = {k: v for k, v in payload.items() if v} - response = self._send_request(payload) - - response_data = raise_for_status(response).json() - message = self.anthropic_response_to_message(response_data) - usage = self.get_usage(response_data) + response = self._post(payload) + message = self.anthropic_response_to_message(response) + usage = self.get_usage(response) return message, usage - @retry_httpx_request() - def _send_request(self, payload: Dict[str, Any]) -> httpx.Response: - return self.client.post(ANTHROPIC_HOST, json=payload) + @retry_procedure + def _post(self, payload: dict) -> httpx.Response: + response = self.client.post(ANTHROPIC_HOST, json=payload) + return raise_for_status(response).json() diff --git a/src/exchange/providers/azure.py b/src/exchange/providers/azure.py index 48e2747..a2afc4c 100644 --- a/src/exchange/providers/azure.py +++ b/src/exchange/providers/azure.py @@ -5,7 +5,8 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -15,6 +16,13 @@ ) from exchange.tool import Tool +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + class AzureProvider(Provider): """Provides chat completions for models hosted directly by OpenAI""" @@ -91,18 +99,17 @@ def complete( payload = {k: v for k, v in payload.items() if v} request_url = f"{self.client.base_url}/chat/completions?api-version={self.api_version}" - response = self._send_request(payload, request_url) + response = self._post(payload, request_url) # Check for context_length_exceeded error for single, long input message - if "error" in response.json() and len(messages) == 1: - openai_single_message_context_length_exceeded(response.json()["error"]) - - data = raise_for_status(response).json() + if "error" in response and len(messages) == 1: + openai_single_message_context_length_exceeded(response["error"]) - message = openai_response_to_message(data) - usage = self.get_usage(data) + message = openai_response_to_message(response) + usage = self.get_usage(response) return message, usage - @retry_httpx_request() - def _send_request(self, payload: Any, request_url: str) -> httpx.Response: # noqa: ANN401 - return self.client.post(request_url, json=payload) + @retry_procedure + def _post(self, payload: Any, request_url: str) -> dict: + response = self.client.post(request_url, json=payload) + return raise_for_status(response).json() diff --git a/src/exchange/providers/bedrock.py b/src/exchange/providers/bedrock.py index c51f69e..2a5f53d 100644 --- a/src/exchange/providers/bedrock.py +++ b/src/exchange/providers/bedrock.py @@ -12,7 +12,8 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.providers import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status from exchange.providers.utils import raise_for_status from exchange.tool import Tool @@ -21,6 +22,13 @@ logger = logging.getLogger(__name__) +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + class AwsClient(httpx.Client): def __init__( @@ -110,7 +118,7 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name: algorithm = "AWS4-HMAC-SHA256" credential_scope = f"{date_stamp}/{self.region}/{service}/aws4_request" string_to_sign = ( - f'{algorithm}\n{amz_date}\n{credential_scope}\n' + f"{algorithm}\n{amz_date}\n{credential_scope}\n" f'{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}' ) @@ -204,11 +212,10 @@ def complete( payload = {k: v for k, v in payload.items() if v} path = f"{self.client.host}model/{model}/converse" - response = self._send_request(payload, path) - raise_for_status(response) - response_message = response.json()["output"]["message"] + response = self._post(payload, path) + response_message = response["output"]["message"] - usage_data = response.json()["usage"] + usage_data = response["usage"] usage = Usage( input_tokens=usage_data.get("inputTokens"), output_tokens=usage_data.get("outputTokens"), @@ -217,9 +224,10 @@ def complete( return self.response_to_message(response_message), usage - @retry_httpx_request() - def _send_request(self, payload: Any, path: str) -> httpx.Response: # noqa: ANN401 - return self.client.post(path, json=payload) + @retry_procedure + def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401 + response = self.client.post(path, json=payload) + return raise_for_status(response).json() @staticmethod def message_to_bedrock_spec(message: Message) -> dict: diff --git a/src/exchange/providers/databricks.py b/src/exchange/providers/databricks.py index f946c52..84dc751 100644 --- a/src/exchange/providers/databricks.py +++ b/src/exchange/providers/databricks.py @@ -5,16 +5,24 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import raise_for_status, retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, - raise_for_status, tools_to_openai_spec, ) from exchange.tool import Tool +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + class DatabricksProvider(Provider): """Provides chat completions for models on Databricks serving endpoints @@ -80,15 +88,15 @@ def complete( **kwargs, ) payload = {k: v for k, v in payload.items() if v} - response = self._send_request(model, payload) - data = raise_for_status(response).json() - message = openai_response_to_message(data) - usage = self.get_usage(data) + response = self._post(model, payload) + message = openai_response_to_message(response) + usage = self.get_usage(response) return message, usage - @retry_httpx_request() - def _send_request(self, model: str, payload: Any) -> httpx.Response: # noqa: ANN401 - return self.client.post( + @retry_procedure + def _post(self, model: str, payload: dict) -> httpx.Response: + response = self.client.post( f"serving-endpoints/{model}/invocations", json=payload, ) + return raise_for_status(response).json() diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index 1f3133b..1f56eff 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -5,7 +5,6 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -14,9 +13,18 @@ tools_to_openai_spec, ) from exchange.tool import Tool +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status OPENAI_HOST = "https://api.openai.com/" +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + class OpenAiProvider(Provider): """Provides chat completions for models hosted directly by OpenAI""" @@ -65,28 +73,25 @@ def complete( tools: Tuple[Tool], **kwargs: Dict[str, Any], ) -> Tuple[Message, Usage]: + system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}] payload = dict( - messages=[ - {"role": "system", "content": system}, - *messages_to_openai_spec(messages), - ], + messages=system_message + messages_to_openai_spec(messages), model=model, tools=tools_to_openai_spec(tools) if tools else [], **kwargs, ) payload = {k: v for k, v in payload.items() if v} - response = self._send_request(payload) + response = self._post(payload) # Check for context_length_exceeded error for single, long input message - if "error" in response.json() and len(messages) == 1: - openai_single_message_context_length_exceeded(response.json()["error"]) - - data = raise_for_status(response).json() + if "error" in response and len(messages) == 1: + openai_single_message_context_length_exceeded(response["error"]) - message = openai_response_to_message(data) - usage = self.get_usage(data) + message = openai_response_to_message(response) + usage = self.get_usage(response) return message, usage - @retry_httpx_request() - def _send_request(self, payload: Any) -> httpx.Response: # noqa: ANN401 - return self.client.post("v1/chat/completions", json=payload) + @retry_procedure + def _post(self, payload: Any) -> dict: + response = self.client.post("v1/chat/completions", json=payload) + return raise_for_status(response).json() diff --git a/src/exchange/providers/retry_with_back_off_decorator.py b/src/exchange/providers/retry_with_back_off_decorator.py deleted file mode 100644 index c90c3d8..0000000 --- a/src/exchange/providers/retry_with_back_off_decorator.py +++ /dev/null @@ -1,61 +0,0 @@ -import time -from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Optional - -from httpx import HTTPStatusError, Response - - -def retry_with_backoff( - should_retry: Callable, - max_retries: Optional[int] = 5, - initial_wait: Optional[float] = 10, - backoff_factor: Optional[float] = 1, - handle_retry_exhausted: Optional[Callable] = None, -) -> Callable: - def decorator(func: Callable) -> Callable: - @wraps(func) - def wrapper(*args: List, **kwargs: Dict) -> Any: # noqa: ANN401 - result = None - for retry in range(max_retries): - result = func(*args, **kwargs) - if not should_retry(result): - return result - if (retry + 1) == max_retries: - break - sleep_time = initial_wait + (backoff_factor * (2**retry)) - time.sleep(sleep_time) - if handle_retry_exhausted: - handle_retry_exhausted(result, max_retries) - return result - - return wrapper - - return decorator - - -def retry_httpx_request( - retry_on_status_code: Optional[Iterable[int]] = None, - max_retries: Optional[int] = 5, - initial_wait: Optional[float] = 10, - backoff_factor: Optional[float] = 1, -) -> Callable: - if retry_on_status_code is None: - retry_on_status_code = set(range(401, 999)) - - def should_retry(response: Response) -> bool: - return response.status_code in retry_on_status_code - - def handle_retry_exhausted(response: Response, max_retries: int) -> None: - raise HTTPStatusError( - f"Failed after {max_retries}.", - request=response.request, - response=response, - ) - - return retry_with_backoff( - max_retries=max_retries, - initial_wait=initial_wait, - backoff_factor=backoff_factor, - should_retry=should_retry, - handle_retry_exhausted=handle_retry_exhausted, - ) diff --git a/src/exchange/providers/utils.py b/src/exchange/providers/utils.py index f53cbfc..4be7ac3 100644 --- a/src/exchange/providers/utils.py +++ b/src/exchange/providers/utils.py @@ -1,12 +1,27 @@ import base64 import json import re -from typing import Any, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import httpx from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.tool import Tool +from tenacity import retry_if_exception + + +def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable: + codes = codes or [] + + def predicate(exc: Exception) -> bool: + if isinstance(exc, httpx.HTTPStatusError): + if exc.response.status_code in codes: + return True + if above and exc.response.status_code >= above: + return True + return False + + return retry_if_exception(predicate) def raise_for_status(response: httpx.Response) -> httpx.Response: diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index 37ea051..a6f5bc6 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -112,16 +112,15 @@ def test_messages_to_anthropic_spec() -> None: @patch("httpx.Client.post") -@patch("time.sleep", return_value=None) @patch("logging.warning") @patch("logging.error") -def test_anthropic_completion(mock_error, mock_warning, mock_sleep, mock_post, anthropic_provider): +def test_anthropic_completion(mock_error, mock_warning, mock_post, anthropic_provider): mock_response = { "content": [{"type": "text", "text": "Hello from Claude!"}], "usage": {"input_tokens": 10, "output_tokens": 25}, } - # First 4 attempts fail with status code 429, 5th succeeds + # First attempts fail with status code 429, 2nd succeeds def create_response(status_code, json_data=None): response = httpx.Response(status_code) response._content = httpx._content.json_dumps(json_data or {}).encode() @@ -129,11 +128,8 @@ def create_response(status_code, json_data=None): return response mock_post.side_effect = [ - create_response(429), - create_response(429), - create_response(429), - create_response(429), - create_response(200, mock_response), + create_response(429), # 1st attempt + create_response(200, mock_response), # Final success ] model = "claude-3-5-sonnet-20240620" @@ -144,7 +140,7 @@ def create_response(status_code, json_data=None): assert reply_message.content == [Text(text="Hello from Claude!")] assert reply_usage.total_tokens == 35 - assert mock_post.call_count == 5 + assert mock_post.call_count == 2 mock_post.assert_any_call( "https://api.anthropic.com/v1/messages", json={ diff --git a/tests/providers/test_retry_with_back_off_decorator.py b/tests/providers/test_retry_with_back_off_decorator.py deleted file mode 100644 index 7418a9c..0000000 --- a/tests/providers/test_retry_with_back_off_decorator.py +++ /dev/null @@ -1,174 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request, retry_with_backoff -from httpx import HTTPStatusError, Response - - -def create_mock_function(): - mock_function = MagicMock() - mock_function.side_effect = [3, 5, 7] - return mock_function - - -def test_retry_with_backoff_retry_exhausted(): - mock_function = create_mock_function() - handle_retry_exhausted_function = MagicMock() - - def should_try(result): - return result < 7 - - @retry_with_backoff( - should_retry=should_try, - max_retries=2, - initial_wait=0, - backoff_factor=0.001, - handle_retry_exhausted=handle_retry_exhausted_function, - ) - def test_func(): - return mock_function() - - assert test_func() == 5 - - assert mock_function.call_count == 2 - handle_retry_exhausted_function.assert_called_once() - handle_retry_exhausted_function.assert_called_with(5, 2) - - -def test_retry_with_backoff_retry_successful(): - mock_function = create_mock_function() - handle_retry_exhausted_function = MagicMock() - - def should_try(result): - return result < 4 - - @retry_with_backoff( - should_retry=should_try, - max_retries=2, - initial_wait=0, - backoff_factor=0.001, - handle_retry_exhausted=handle_retry_exhausted_function, - ) - def test_func(): - return mock_function() - - assert test_func() == 5 - - assert mock_function.call_count == 2 - handle_retry_exhausted_function.assert_not_called() - - -def test_retry_with_backoff_without_retry(): - mock_function = create_mock_function() - handle_retry_exhausted_function = MagicMock() - - def should_try(result): - return result < 2 - - @retry_with_backoff( - should_retry=should_try, - max_retries=2, - initial_wait=0, - backoff_factor=0.001, - handle_retry_exhausted=handle_retry_exhausted_function, - ) - def test_func(): - return mock_function() - - assert test_func() == 3 - - assert mock_function.call_count == 1 - handle_retry_exhausted_function.assert_not_called() - - -def create_mock_httpx_request_call_function(responses=[500, 429, 200]): - mock_function = MagicMock() - mock_responses = [] - for response_code in responses: - response = MagicMock() - response.status_code = response_code - mock_responses.append(response) - - mock_function.side_effect = mock_responses - return mock_function - - -def test_retry_httpx_request_backoff_retry_exhausted(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function() - - @retry_httpx_request(retry_on_status_code=[500, 429], max_retries=2, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - with pytest.raises(HTTPStatusError): - test_func() - - assert mock_httpx_request_call_function.call_count == 2 - - -def test_retry_httpx_request_backoff_retry_successful(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function() - - @retry_httpx_request(retry_on_status_code=[500], max_retries=2, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - assert test_func().status_code == 429 - - assert mock_httpx_request_call_function.call_count == 2 - - -def test_retry_httpx_request_backoff_without_retry(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function() - - @retry_httpx_request(retry_on_status_code=[503], max_retries=2, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - assert test_func().status_code == 500 - - assert mock_httpx_request_call_function.call_count == 1 - - -def test_retry_httpx_request_backoff_range(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[200]) - - @retry_httpx_request(max_retries=2, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - assert test_func().status_code == 200 - - assert mock_httpx_request_call_function.call_count == 1 - - -def test_retry_httpx_request_backoff_range_retry_never_succeed(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[401, 500, 500]) - - @retry_httpx_request(max_retries=3, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - # Never gets a successful response - with pytest.raises(HTTPStatusError): - f = test_func() - # last error is 500 - assert f.status_code == 500 - - # Has been retried 3 times - assert mock_httpx_request_call_function.call_count == 3 - - -def test_retry_httpx_request_backoff_range_retry_succeed(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[401, 500, 200]) - - @retry_httpx_request(max_retries=3, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - # Retries and raises no error - f = test_func() - assert f.status_code == 200 - - # Has been retried 3 times - assert mock_httpx_request_call_function.call_count == 3