diff --git a/ai21/http_client/async_http_client.py b/ai21/http_client/async_http_client.py index 13fd536d..cc3f49ec 100644 --- a/ai21/http_client/async_http_client.py +++ b/ai21/http_client/async_http_client.py @@ -55,7 +55,7 @@ def __init__( wait=wait_exponential(multiplier=RETRY_BACK_OFF_FACTOR, min=TIME_BETWEEN_RETRIES), retry=retry_if_result(self._should_retry), stop=stop_after_attempt(self._num_retries), - )(self._request) + )(self._run_request) self._streaming_decoder = _SSEDecoder() async def execute_http_request( @@ -103,11 +103,16 @@ async def execute_http_request( logger.error( f"Calling {method} {self._base_url} failed with a non-200 response code: {response.status_code}" ) - handle_non_success_response(response.status_code, response.text) + + if stream: + details = self._extract_streaming_error_details(response) + handle_non_success_response(response.status_code, details) + else: + handle_non_success_response(response.status_code, response.text) return response - async def _request(self, options: RequestOptions) -> httpx.Response: + async def _run_request(self, options: RequestOptions) -> httpx.Response: request = self._build_request(options) _logger.debug(f"Calling {request.method} {request.url} {request.headers}, {options.body}") diff --git a/ai21/http_client/base_http_client.py b/ai21/http_client/base_http_client.py index 669b82dc..4660aae7 100644 --- a/ai21/http_client/base_http_client.py +++ b/ai21/http_client/base_http_client.py @@ -118,7 +118,7 @@ def execute_http_request( pass @abstractmethod - def _request( + def _run_request( self, options: RequestOptions, ) -> httpx.Response: @@ -171,3 +171,9 @@ def _prepare_url(self, options: RequestOptions) -> str: return f"{options.url}{options.path}" return options.url + + def _extract_streaming_error_details(self, response: httpx.Response) -> str: + try: + return response.read().decode("utf-8") + except Exception: + return "could not extract streaming error details" diff --git a/ai21/http_client/http_client.py b/ai21/http_client/http_client.py index bacc10bb..9aa4a8e4 100644 --- a/ai21/http_client/http_client.py +++ b/ai21/http_client/http_client.py @@ -54,7 +54,7 @@ def __init__( wait=wait_exponential(multiplier=RETRY_BACK_OFF_FACTOR, min=TIME_BETWEEN_RETRIES), retry=retry_if_result(self._should_retry), stop=stop_after_attempt(self._num_retries), - )(self._request) + )(self._run_request) self._streaming_decoder = _SSEDecoder() def execute_http_request( @@ -102,11 +102,16 @@ def execute_http_request( f"Calling {method} {self._base_url} failed with a non-200 " f"response code: {response.status_code} headers: {response.headers}" ) - handle_non_success_response(response.status_code, response.text) + + if stream: + details = self._extract_streaming_error_details(response) + handle_non_success_response(response.status_code, details) + else: + handle_non_success_response(response.status_code, response.text) return response - def _request(self, options: RequestOptions) -> httpx.Response: + def _run_request(self, options: RequestOptions) -> httpx.Response: request = self._build_request(options) _logger.debug(f"Calling {request.method} {request.url} {request.headers}, {options.body}") diff --git a/ai21/models/responses/embed_response.py b/ai21/models/responses/embed_response.py index 7ded40f6..6cb10921 100644 --- a/ai21/models/responses/embed_response.py +++ b/ai21/models/responses/embed_response.py @@ -6,6 +6,9 @@ class EmbedResult(AI21BaseModel): embedding: List[float] + def __init__(self, embedding: List[float]): + super().__init__(embedding=embedding) + class EmbedResponse(AI21BaseModel): id: str diff --git a/tests/unittests/test_http_client.py b/tests/unittests/test_http_client.py index 9b68fbcc..02c44542 100644 --- a/tests/unittests/test_http_client.py +++ b/tests/unittests/test_http_client.py @@ -4,7 +4,7 @@ import httpx -from ai21.errors import ServiceUnavailable +from ai21.errors import ServiceUnavailable, Unauthorized from ai21.http_client.base_http_client import RETRY_ERROR_CODES from ai21.http_client.http_client import AI21HTTPClient from ai21.http_client.async_http_client import AsyncAI21HTTPClient @@ -42,6 +42,17 @@ def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_htt assert mock_httpx_client.send.call_count == retries +def test__execute_http_request__when_streaming__should_handle_non_200_response_code(mock_httpx_client: Mock) -> None: + error_details = "test_error" + request = Request(method=_METHOD, url=_URL) + response = httpx.Response(status_code=401, request=request, text=error_details) + mock_httpx_client.send.return_value = response + + client = AI21HTTPClient(client=mock_httpx_client, base_url=_URL, api_key=_API_KEY) + with pytest.raises(Unauthorized, match=error_details): + client.execute_http_request(method=_METHOD, stream=True) + + @pytest.mark.asyncio async def test__execute_async_http_request__when_retry_error_code_once__should_retry_and_succeed( mock_httpx_async_client: Mock, @@ -74,3 +85,17 @@ async def test__execute_async_http_request__when_retry_error__should_retry_and_s await client.execute_http_request(method=_METHOD) assert mock_httpx_async_client.send.call_count == retries + + +@pytest.mark.asyncio +async def test__execute_async_http_request__when_streaming__should_handle_non_200_response_code( + mock_httpx_async_client: Mock, +) -> None: + error_details = "test_error" + request = Request(method=_METHOD, url=_URL) + response = httpx.Response(status_code=401, request=request, text=error_details) + mock_httpx_async_client.send.return_value = response + + client = AsyncAI21HTTPClient(client=mock_httpx_async_client, base_url=_URL, api_key=_API_KEY) + with pytest.raises(Unauthorized, match=error_details): + await client.execute_http_request(method=_METHOD, stream=True)