feat: Rework error handling
The previous implementation retried too inclusively in many cases, and
it also squashed the original error output, which made it hard to
identify setup problems. This approach lets us retry flaky errors while
still surfacing real issues like missing access, incorrect model ids,
etc.

Adding tenacity as a dependency: it has an Apache license, zero
dependencies, and provides a well-established standard for error handling.
baxen committed Sep 13, 2024
1 parent c4365bf commit dd5595e
Showing 10 changed files with 110 additions and 298 deletions.
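
For orientation, the pattern each provider below adopts, condensed into a single sketch. The names and values are taken straight from the diff; treat this as an illustration, not extra code in the commit:

import httpx
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status

# At most two attempts (i.e. one retry), 2 seconds apart, and only for
# "flaky" statuses: 429 (rate limited) or anything >= 500. reraise=True
# surfaces the original httpx.HTTPStatusError instead of tenacity's
# RetryError wrapper, so real failures keep their original error output.
retry_procedure = retry(
    wait=wait_fixed(2),
    stop=stop_after_attempt(2),
    retry=retry_if_status(codes=[429], above=500),
    reraise=True,
)

@retry_procedure
def _post(client: httpx.Client, url: str, payload: dict) -> dict:
    response = client.post(url, json=payload)
    return raise_for_status(response).json()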
1 change: 1 addition & 0 deletions pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
     "jinja2>=3.1.4",
     "tiktoken>=0.7.0",
     "httpx>=0.27.0",
+    "tenacity>=9.0.0",
 ]
 
 [tool.hatch.build.targets.wheel]
25 changes: 16 additions & 9 deletions src/exchange/providers/anthropic.py
@@ -6,11 +6,19 @@
 from exchange import Message, Tool
 from exchange.content import Text, ToolResult, ToolUse
 from exchange.providers.base import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import retry_if_status
 from exchange.providers.utils import raise_for_status
 
 ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class AnthropicProvider(Provider):
     def __init__(self, client: httpx.Client) -> None:
@@ -138,14 +146,13 @@ def complete(
         )
         payload = {k: v for k, v in payload.items() if v}
 
-        response = self._send_request(payload)
-
-        response_data = raise_for_status(response).json()
-        message = self.anthropic_response_to_message(response_data)
-        usage = self.get_usage(response_data)
+        response = self._post(payload)
+        message = self.anthropic_response_to_message(response)
+        usage = self.get_usage(response)
 
         return message, usage
 
-    @retry_httpx_request()
-    def _send_request(self, payload: Dict[str, Any]) -> httpx.Response:
-        return self.client.post(ANTHROPIC_HOST, json=payload)
+    @retry_procedure
+    def _post(self, payload: dict) -> dict:
+        response = self.client.post(ANTHROPIC_HOST, json=payload)
+        return raise_for_status(response).json()
29 changes: 18 additions & 11 deletions src/exchange/providers/azure.py
@@ -5,7 +5,8 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import retry_if_status
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
@@ -15,6 +16,13 @@
 )
 from exchange.tool import Tool
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class AzureProvider(Provider):
     """Provides chat completions for models hosted directly by OpenAI"""
@@ -91,18 +99,17 @@ def complete(
 
         payload = {k: v for k, v in payload.items() if v}
         request_url = f"{self.client.base_url}/chat/completions?api-version={self.api_version}"
-        response = self._send_request(payload, request_url)
+        response = self._post(payload, request_url)
 
         # Check for context_length_exceeded error for single, long input message
-        if "error" in response.json() and len(messages) == 1:
-            openai_single_message_context_length_exceeded(response.json()["error"])
-
-        data = raise_for_status(response).json()
+        if "error" in response and len(messages) == 1:
+            openai_single_message_context_length_exceeded(response["error"])
 
-        message = openai_response_to_message(data)
-        usage = self.get_usage(data)
+        message = openai_response_to_message(response)
+        usage = self.get_usage(response)
         return message, usage
 
-    @retry_httpx_request()
-    def _send_request(self, payload: Any, request_url: str) -> httpx.Response:  # noqa: ANN401
-        return self.client.post(request_url, json=payload)
+    @retry_procedure
+    def _post(self, payload: Any, request_url: str) -> dict:
+        response = self.client.post(request_url, json=payload)
+        return raise_for_status(response).json()
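
A subtlety in the Azure (and OpenAI) path, noted as an observation rather than something stated in the commit: _post now calls raise_for_status before returning, so a non-2xx response raises inside _post, and the "error" check below it only fires for successful responses whose JSON body embeds an error object:

response = self._post(payload, request_url)     # raises httpx.HTTPStatusError on 4xx/5xx
if "error" in response and len(messages) == 1:  # reached only when the HTTP status was OK
    openai_single_message_context_length_exceeded(response["error"])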
26 changes: 17 additions & 9 deletions src/exchange/providers/bedrock.py
@@ -12,7 +12,8 @@
 from exchange.content import Text, ToolResult, ToolUse
 from exchange.message import Message
 from exchange.providers import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import retry_if_status
 from exchange.providers.utils import raise_for_status
 from exchange.tool import Tool
 
@@ -21,6 +22,13 @@
 
 logger = logging.getLogger(__name__)
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class AwsClient(httpx.Client):
     def __init__(
@@ -110,7 +118,7 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name:
         algorithm = "AWS4-HMAC-SHA256"
         credential_scope = f"{date_stamp}/{self.region}/{service}/aws4_request"
         string_to_sign = (
-            f'{algorithm}\n{amz_date}\n{credential_scope}\n'
+            f"{algorithm}\n{amz_date}\n{credential_scope}\n"
             f'{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}'
         )
 
@@ -204,11 +212,10 @@ def complete(
         payload = {k: v for k, v in payload.items() if v}
 
         path = f"{self.client.host}model/{model}/converse"
-        response = self._send_request(payload, path)
-        raise_for_status(response)
-        response_message = response.json()["output"]["message"]
+        response = self._post(payload, path)
+        response_message = response["output"]["message"]
 
-        usage_data = response.json()["usage"]
+        usage_data = response["usage"]
         usage = Usage(
             input_tokens=usage_data.get("inputTokens"),
             output_tokens=usage_data.get("outputTokens"),
@@ -217,9 +224,10 @@
 
         return self.response_to_message(response_message), usage
 
-    @retry_httpx_request()
-    def _send_request(self, payload: Any, path: str) -> httpx.Response:  # noqa: ANN401
-        return self.client.post(path, json=payload)
+    @retry_procedure
+    def _post(self, payload: Any, path: str) -> dict:  # noqa: ANN401
+        response = self.client.post(path, json=payload)
+        return raise_for_status(response).json()
 
     @staticmethod
     def message_to_bedrock_spec(message: Message) -> dict:
26 changes: 17 additions & 9 deletions src/exchange/providers/databricks.py
@@ -5,16 +5,24 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import raise_for_status, retry_if_status
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
-    raise_for_status,
     tools_to_openai_spec,
 )
 from exchange.tool import Tool
 
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class DatabricksProvider(Provider):
     """Provides chat completions for models on Databricks serving endpoints
@@ -80,15 +88,15 @@ def complete(
             **kwargs,
         )
         payload = {k: v for k, v in payload.items() if v}
-        response = self._send_request(model, payload)
-        data = raise_for_status(response).json()
-        message = openai_response_to_message(data)
-        usage = self.get_usage(data)
+        response = self._post(model, payload)
+        message = openai_response_to_message(response)
+        usage = self.get_usage(response)
         return message, usage
 
-    @retry_httpx_request()
-    def _send_request(self, model: str, payload: Any) -> httpx.Response:  # noqa: ANN401
-        return self.client.post(
+    @retry_procedure
+    def _post(self, model: str, payload: dict) -> dict:
+        response = self.client.post(
             f"serving-endpoints/{model}/invocations",
             json=payload,
         )
+        return raise_for_status(response).json()
35 changes: 20 additions & 15 deletions src/exchange/providers/openai.py
@@ -5,7 +5,6 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
@@ -14,9 +13,18 @@
     tools_to_openai_spec,
 )
 from exchange.tool import Tool
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import retry_if_status
 
 OPENAI_HOST = "https://api.openai.com/"
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class OpenAiProvider(Provider):
     """Provides chat completions for models hosted directly by OpenAI"""
@@ -65,28 +73,25 @@ def complete(
         tools: Tuple[Tool],
         **kwargs: Dict[str, Any],
     ) -> Tuple[Message, Usage]:
+        system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}]
         payload = dict(
-            messages=[
-                {"role": "system", "content": system},
-                *messages_to_openai_spec(messages),
-            ],
+            messages=system_message + messages_to_openai_spec(messages),
             model=model,
             tools=tools_to_openai_spec(tools) if tools else [],
             **kwargs,
         )
         payload = {k: v for k, v in payload.items() if v}
-        response = self._send_request(payload)
+        response = self._post(payload)
 
         # Check for context_length_exceeded error for single, long input message
-        if "error" in response.json() and len(messages) == 1:
-            openai_single_message_context_length_exceeded(response.json()["error"])
-
-        data = raise_for_status(response).json()
+        if "error" in response and len(messages) == 1:
+            openai_single_message_context_length_exceeded(response["error"])
 
-        message = openai_response_to_message(data)
-        usage = self.get_usage(data)
+        message = openai_response_to_message(response)
+        usage = self.get_usage(response)
         return message, usage
 
-    @retry_httpx_request()
-    def _send_request(self, payload: Any) -> httpx.Response:  # noqa: ANN401
-        return self.client.post("v1/chat/completions", json=payload)
+    @retry_procedure
+    def _post(self, payload: Any) -> dict:
+        response = self.client.post("v1/chat/completions", json=payload)
+        return raise_for_status(response).json()
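
The system_message change above is unrelated to the retry rework: o1-series models rejected the "system" role at the time of this commit, so the system prompt is omitted for model ids starting with "o1" (an inference from the diff, not stated in the commit message):

system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}]
# model="o1-preview" -> []                                      (no system message sent)
# model="gpt-4o"     -> [{"role": "system", "content": system}] (behavior unchanged)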
61 changes: 0 additions & 61 deletions src/exchange/providers/retry_with_back_off_decorator.py

This file was deleted.

17 changes: 16 additions & 1 deletion src/exchange/providers/utils.py
@@ -1,12 +1,27 @@
 import base64
 import json
 import re
-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import httpx
 from exchange.content import Text, ToolResult, ToolUse
 from exchange.message import Message
 from exchange.tool import Tool
+from tenacity import retry_if_exception
+
+
+def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable:
+    codes = codes or []
+
+    def predicate(exc: Exception) -> bool:
+        if isinstance(exc, httpx.HTTPStatusError):
+            if exc.response.status_code in codes:
+                return True
+            if above and exc.response.status_code >= above:
+                return True
+        return False
+
+    return retry_if_exception(predicate)
 
 
 def raise_for_status(response: httpx.Response) -> httpx.Response:
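
To see the new behavior end to end, a test-style sketch (hypothetical, not part of the commit). httpx.MockTransport is standard httpx; the flow assumes raise_for_status raises httpx.HTTPStatusError for error statuses, which is the exception type retry_if_status matches:

import httpx
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status

calls = []

def handler(request: httpx.Request) -> httpx.Response:
    calls.append(request)
    # First attempt is rate limited; the retry succeeds.
    return httpx.Response(429 if len(calls) == 1 else 200, json={"ok": True})

client = httpx.Client(transport=httpx.MockTransport(handler))

@retry(wait=wait_fixed(0), stop=stop_after_attempt(2),
       retry=retry_if_status(codes=[429], above=500), reraise=True)
def post() -> dict:
    return raise_for_status(client.post("https://example.com/v1/messages")).json()

assert post() == {"ok": True}
assert len(calls) == 2  # the 429 was retried exactly once
# A 401 or 404, by contrast, does not match the predicate and raises immediately,
# preserving the original error output for problems like bad credentials or model ids.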