diff --git a/pyproject.toml b/pyproject.toml index f4c1a85..797e53d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "jinja2>=3.1.4", "tiktoken>=0.7.0", "httpx>=0.27.0", + "tenacity>=9.0.0", ] [tool.hatch.build.targets.wheel] diff --git a/src/exchange/providers/anthropic.py b/src/exchange/providers/anthropic.py index 7319ee8..154ec5f 100644 --- a/src/exchange/providers/anthropic.py +++ b/src/exchange/providers/anthropic.py @@ -6,11 +6,19 @@ from exchange import Message, Tool from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status from exchange.providers.utils import raise_for_status ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + class AnthropicProvider(Provider): def __init__(self, client: httpx.Client) -> None: @@ -138,14 +146,13 @@ def complete( ) payload = {k: v for k, v in payload.items() if v} - response = self._send_request(payload) - - response_data = raise_for_status(response).json() - message = self.anthropic_response_to_message(response_data) - usage = self.get_usage(response_data) + response = self._post(payload) + message = self.anthropic_response_to_message(response) + usage = self.get_usage(response) return message, usage - @retry_httpx_request() - def _send_request(self, payload: Dict[str, Any]) -> httpx.Response: - return self.client.post(ANTHROPIC_HOST, json=payload) + @retry_procedure + def _post(self, payload: dict) -> dict: + response = self.client.post(ANTHROPIC_HOST, json=payload) + return raise_for_status(response).json() diff --git a/src/exchange/providers/azure.py b/src/exchange/providers/azure.py 
index 48e2747..a2afc4c 100644 --- a/src/exchange/providers/azure.py +++ b/src/exchange/providers/azure.py @@ -5,7 +5,8 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -15,6 +16,13 @@ ) from exchange.tool import Tool +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + class AzureProvider(Provider): """Provides chat completions for models hosted directly by OpenAI""" @@ -91,18 +99,17 @@ def complete( payload = {k: v for k, v in payload.items() if v} request_url = f"{self.client.base_url}/chat/completions?api-version={self.api_version}" - response = self._send_request(payload, request_url) + response = self._post(payload, request_url) # Check for context_length_exceeded error for single, long input message - if "error" in response.json() and len(messages) == 1: - openai_single_message_context_length_exceeded(response.json()["error"]) - - data = raise_for_status(response).json() + if "error" in response and len(messages) == 1: + openai_single_message_context_length_exceeded(response["error"]) - message = openai_response_to_message(data) - usage = self.get_usage(data) + message = openai_response_to_message(response) + usage = self.get_usage(response) return message, usage - @retry_httpx_request() - def _send_request(self, payload: Any, request_url: str) -> httpx.Response: # noqa: ANN401 - return self.client.post(request_url, json=payload) + @retry_procedure + def _post(self, payload: Any, request_url: str) -> dict: + response = self.client.post(request_url, json=payload) + return raise_for_status(response).json() diff --git 
a/src/exchange/providers/bedrock.py b/src/exchange/providers/bedrock.py index c51f69e..2a5f53d 100644 --- a/src/exchange/providers/bedrock.py +++ b/src/exchange/providers/bedrock.py @@ -12,7 +12,8 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.providers import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status from exchange.providers.utils import raise_for_status from exchange.tool import Tool @@ -21,6 +22,13 @@ logger = logging.getLogger(__name__) +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + class AwsClient(httpx.Client): def __init__( @@ -110,7 +118,7 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name: algorithm = "AWS4-HMAC-SHA256" credential_scope = f"{date_stamp}/{self.region}/{service}/aws4_request" string_to_sign = ( - f'{algorithm}\n{amz_date}\n{credential_scope}\n' + f"{algorithm}\n{amz_date}\n{credential_scope}\n" f'{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}' ) @@ -204,11 +212,10 @@ def complete( payload = {k: v for k, v in payload.items() if v} path = f"{self.client.host}model/{model}/converse" - response = self._send_request(payload, path) - raise_for_status(response) - response_message = response.json()["output"]["message"] + response = self._post(payload, path) + response_message = response["output"]["message"] - usage_data = response.json()["usage"] + usage_data = response["usage"] usage = Usage( input_tokens=usage_data.get("inputTokens"), output_tokens=usage_data.get("outputTokens"), @@ -217,9 +224,10 @@ def complete( return self.response_to_message(response_message), usage - @retry_httpx_request() - def _send_request(self, payload: Any, path: str) -> httpx.Response: # noqa: 
ANN401 - return self.client.post(path, json=payload) + @retry_procedure + def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401 + response = self.client.post(path, json=payload) + return raise_for_status(response).json() @staticmethod def message_to_bedrock_spec(message: Message) -> dict: diff --git a/src/exchange/providers/databricks.py b/src/exchange/providers/databricks.py index f946c52..84dc751 100644 --- a/src/exchange/providers/databricks.py +++ b/src/exchange/providers/databricks.py @@ -5,16 +5,24 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import raise_for_status, retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, - raise_for_status, tools_to_openai_spec, ) from exchange.tool import Tool +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + class DatabricksProvider(Provider): """Provides chat completions for models on Databricks serving endpoints @@ -80,15 +88,15 @@ def complete( **kwargs, ) payload = {k: v for k, v in payload.items() if v} - response = self._send_request(model, payload) - data = raise_for_status(response).json() - message = openai_response_to_message(data) - usage = self.get_usage(data) + response = self._post(model, payload) + message = openai_response_to_message(response) + usage = self.get_usage(response) return message, usage - @retry_httpx_request() - def _send_request(self, model: str, payload: Any) -> httpx.Response: # noqa: ANN401 - return self.client.post( + @retry_procedure + def _post(self, model: str, payload: dict) -> dict: + response = self.client.post( f"serving-endpoints/{model}/invocations", json=payload, ) + return 
raise_for_status(response).json() diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index 1f3133b..1f56eff 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -5,7 +5,6 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -14,9 +13,18 @@ tools_to_openai_spec, ) from exchange.tool import Tool +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status OPENAI_HOST = "https://api.openai.com/" +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + class OpenAiProvider(Provider): """Provides chat completions for models hosted directly by OpenAI""" @@ -65,28 +73,25 @@ def complete( tools: Tuple[Tool], **kwargs: Dict[str, Any], ) -> Tuple[Message, Usage]: + system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}] payload = dict( - messages=[ - {"role": "system", "content": system}, - *messages_to_openai_spec(messages), - ], + messages=system_message + messages_to_openai_spec(messages), model=model, tools=tools_to_openai_spec(tools) if tools else [], **kwargs, ) payload = {k: v for k, v in payload.items() if v} - response = self._send_request(payload) + response = self._post(payload) # Check for context_length_exceeded error for single, long input message - if "error" in response.json() and len(messages) == 1: - openai_single_message_context_length_exceeded(response.json()["error"]) - - data = raise_for_status(response).json() + if "error" in response and len(messages) == 1: + openai_single_message_context_length_exceeded(response["error"]) - message = openai_response_to_message(data) - usage = self.get_usage(data) + 
message = openai_response_to_message(response) + usage = self.get_usage(response) return message, usage - @retry_httpx_request() - def _send_request(self, payload: Any) -> httpx.Response: # noqa: ANN401 - return self.client.post("v1/chat/completions", json=payload) + @retry_procedure + def _post(self, payload: Any) -> dict: + response = self.client.post("v1/chat/completions", json=payload) + return raise_for_status(response).json() diff --git a/src/exchange/providers/retry_with_back_off_decorator.py b/src/exchange/providers/retry_with_back_off_decorator.py deleted file mode 100644 index c90c3d8..0000000 --- a/src/exchange/providers/retry_with_back_off_decorator.py +++ /dev/null @@ -1,61 +0,0 @@ -import time -from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Optional - -from httpx import HTTPStatusError, Response - - -def retry_with_backoff( - should_retry: Callable, - max_retries: Optional[int] = 5, - initial_wait: Optional[float] = 10, - backoff_factor: Optional[float] = 1, - handle_retry_exhausted: Optional[Callable] = None, -) -> Callable: - def decorator(func: Callable) -> Callable: - @wraps(func) - def wrapper(*args: List, **kwargs: Dict) -> Any: # noqa: ANN401 - result = None - for retry in range(max_retries): - result = func(*args, **kwargs) - if not should_retry(result): - return result - if (retry + 1) == max_retries: - break - sleep_time = initial_wait + (backoff_factor * (2**retry)) - time.sleep(sleep_time) - if handle_retry_exhausted: - handle_retry_exhausted(result, max_retries) - return result - - return wrapper - - return decorator - - -def retry_httpx_request( - retry_on_status_code: Optional[Iterable[int]] = None, - max_retries: Optional[int] = 5, - initial_wait: Optional[float] = 10, - backoff_factor: Optional[float] = 1, -) -> Callable: - if retry_on_status_code is None: - retry_on_status_code = set(range(401, 999)) - - def should_retry(response: Response) -> bool: - return response.status_code in 
retry_on_status_code - - def handle_retry_exhausted(response: Response, max_retries: int) -> None: - raise HTTPStatusError( - f"Failed after {max_retries}.", - request=response.request, - response=response, - ) - - return retry_with_backoff( - max_retries=max_retries, - initial_wait=initial_wait, - backoff_factor=backoff_factor, - should_retry=should_retry, - handle_retry_exhausted=handle_retry_exhausted, - ) diff --git a/src/exchange/providers/utils.py b/src/exchange/providers/utils.py index f53cbfc..4be7ac3 100644 --- a/src/exchange/providers/utils.py +++ b/src/exchange/providers/utils.py @@ -1,12 +1,27 @@ import base64 import json import re -from typing import Any, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import httpx from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.tool import Tool +from tenacity import retry_if_exception + + +def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable: + codes = codes or [] + + def predicate(exc: Exception) -> bool: + if isinstance(exc, httpx.HTTPStatusError): + if exc.response.status_code in codes: + return True + if above and exc.response.status_code >= above: + return True + return False + + return retry_if_exception(predicate) def raise_for_status(response: httpx.Response) -> httpx.Response: diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index 37ea051..a6f5bc6 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -112,16 +112,15 @@ def test_messages_to_anthropic_spec() -> None: @patch("httpx.Client.post") -@patch("time.sleep", return_value=None) @patch("logging.warning") @patch("logging.error") -def test_anthropic_completion(mock_error, mock_warning, mock_sleep, mock_post, anthropic_provider): +def test_anthropic_completion(mock_error, mock_warning, mock_post, anthropic_provider): mock_response = { "content": [{"type": 
"text", "text": "Hello from Claude!"}], "usage": {"input_tokens": 10, "output_tokens": 25}, } - # First 4 attempts fail with status code 429, 5th succeeds + # First attempts fail with status code 429, 2nd succeeds def create_response(status_code, json_data=None): response = httpx.Response(status_code) response._content = httpx._content.json_dumps(json_data or {}).encode() @@ -129,11 +128,8 @@ def create_response(status_code, json_data=None): return response mock_post.side_effect = [ - create_response(429), - create_response(429), - create_response(429), - create_response(429), - create_response(200, mock_response), + create_response(429), # 1st attempt + create_response(200, mock_response), # Final success ] model = "claude-3-5-sonnet-20240620" @@ -144,7 +140,7 @@ def create_response(status_code, json_data=None): assert reply_message.content == [Text(text="Hello from Claude!")] assert reply_usage.total_tokens == 35 - assert mock_post.call_count == 5 + assert mock_post.call_count == 2 mock_post.assert_any_call( "https://api.anthropic.com/v1/messages", json={ diff --git a/tests/providers/test_retry_with_back_off_decorator.py b/tests/providers/test_retry_with_back_off_decorator.py deleted file mode 100644 index 7418a9c..0000000 --- a/tests/providers/test_retry_with_back_off_decorator.py +++ /dev/null @@ -1,174 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request, retry_with_backoff -from httpx import HTTPStatusError, Response - - -def create_mock_function(): - mock_function = MagicMock() - mock_function.side_effect = [3, 5, 7] - return mock_function - - -def test_retry_with_backoff_retry_exhausted(): - mock_function = create_mock_function() - handle_retry_exhausted_function = MagicMock() - - def should_try(result): - return result < 7 - - @retry_with_backoff( - should_retry=should_try, - max_retries=2, - initial_wait=0, - backoff_factor=0.001, - 
handle_retry_exhausted=handle_retry_exhausted_function, - ) - def test_func(): - return mock_function() - - assert test_func() == 5 - - assert mock_function.call_count == 2 - handle_retry_exhausted_function.assert_called_once() - handle_retry_exhausted_function.assert_called_with(5, 2) - - -def test_retry_with_backoff_retry_successful(): - mock_function = create_mock_function() - handle_retry_exhausted_function = MagicMock() - - def should_try(result): - return result < 4 - - @retry_with_backoff( - should_retry=should_try, - max_retries=2, - initial_wait=0, - backoff_factor=0.001, - handle_retry_exhausted=handle_retry_exhausted_function, - ) - def test_func(): - return mock_function() - - assert test_func() == 5 - - assert mock_function.call_count == 2 - handle_retry_exhausted_function.assert_not_called() - - -def test_retry_with_backoff_without_retry(): - mock_function = create_mock_function() - handle_retry_exhausted_function = MagicMock() - - def should_try(result): - return result < 2 - - @retry_with_backoff( - should_retry=should_try, - max_retries=2, - initial_wait=0, - backoff_factor=0.001, - handle_retry_exhausted=handle_retry_exhausted_function, - ) - def test_func(): - return mock_function() - - assert test_func() == 3 - - assert mock_function.call_count == 1 - handle_retry_exhausted_function.assert_not_called() - - -def create_mock_httpx_request_call_function(responses=[500, 429, 200]): - mock_function = MagicMock() - mock_responses = [] - for response_code in responses: - response = MagicMock() - response.status_code = response_code - mock_responses.append(response) - - mock_function.side_effect = mock_responses - return mock_function - - -def test_retry_httpx_request_backoff_retry_exhausted(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function() - - @retry_httpx_request(retry_on_status_code=[500, 429], max_retries=2, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return 
mock_httpx_request_call_function() - - with pytest.raises(HTTPStatusError): - test_func() - - assert mock_httpx_request_call_function.call_count == 2 - - -def test_retry_httpx_request_backoff_retry_successful(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function() - - @retry_httpx_request(retry_on_status_code=[500], max_retries=2, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - assert test_func().status_code == 429 - - assert mock_httpx_request_call_function.call_count == 2 - - -def test_retry_httpx_request_backoff_without_retry(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function() - - @retry_httpx_request(retry_on_status_code=[503], max_retries=2, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - assert test_func().status_code == 500 - - assert mock_httpx_request_call_function.call_count == 1 - - -def test_retry_httpx_request_backoff_range(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[200]) - - @retry_httpx_request(max_retries=2, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - assert test_func().status_code == 200 - - assert mock_httpx_request_call_function.call_count == 1 - - -def test_retry_httpx_request_backoff_range_retry_never_succeed(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[401, 500, 500]) - - @retry_httpx_request(max_retries=3, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - # Never gets a successful response - with pytest.raises(HTTPStatusError): - f = test_func() - # last error is 500 - assert f.status_code == 500 - - # Has been retried 3 times - assert mock_httpx_request_call_function.call_count == 3 - - -def 
test_retry_httpx_request_backoff_range_retry_succeed(): - mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[401, 500, 200]) - - @retry_httpx_request(max_retries=3, initial_wait=0, backoff_factor=0.001) - def test_func() -> Response: - return mock_httpx_request_call_function() - - # Retries and raises no error - f = test_func() - assert f.status_code == 200 - - # Has been retried 3 times - assert mock_httpx_request_call_function.call_count == 3