diff --git a/ai21/http_client/async_http_client.py b/ai21/http_client/async_http_client.py index 13fd536d..a005e333 100644 --- a/ai21/http_client/async_http_client.py +++ b/ai21/http_client/async_http_client.py @@ -103,7 +103,12 @@ async def execute_http_request( logger.error( f"Calling {method} {self._base_url} failed with a non-200 response code: {response.status_code}" ) - handle_non_success_response(response.status_code, response.text) + + if stream: + details = self._extract_streaming_error_details(response) + handle_non_success_response(response.status_code, details) + else: + handle_non_success_response(response.status_code, response.text) return response diff --git a/ai21/http_client/base_http_client.py b/ai21/http_client/base_http_client.py index 669b82dc..8339715c 100644 --- a/ai21/http_client/base_http_client.py +++ b/ai21/http_client/base_http_client.py @@ -171,3 +171,6 @@ def _prepare_url(self, options: RequestOptions) -> str: return f"{options.url}{options.path}" return options.url + + def _extract_streaming_error_details(self, response: httpx.Response) -> str: + return response.read().decode("utf-8") diff --git a/ai21/http_client/http_client.py b/ai21/http_client/http_client.py index bacc10bb..bed3d994 100644 --- a/ai21/http_client/http_client.py +++ b/ai21/http_client/http_client.py @@ -102,7 +102,12 @@ def execute_http_request( f"Calling {method} {self._base_url} failed with a non-200 " f"response code: {response.status_code} headers: {response.headers}" ) - handle_non_success_response(response.status_code, response.text) + + if stream: + details = self._extract_streaming_error_details(response) + handle_non_success_response(response.status_code, details) + else: + handle_non_success_response(response.status_code, response.text) return response diff --git a/tests/unittests/test_http_client.py b/tests/unittests/test_http_client.py index 9b68fbcc..02c44542 100644 --- a/tests/unittests/test_http_client.py +++ b/tests/unittests/test_http_client.py @@ -4,7 +4,7 @@ import httpx -from ai21.errors import ServiceUnavailable +from ai21.errors import ServiceUnavailable, Unauthorized from ai21.http_client.base_http_client import RETRY_ERROR_CODES from ai21.http_client.http_client import AI21HTTPClient from ai21.http_client.async_http_client import AsyncAI21HTTPClient @@ -42,6 +42,17 @@ def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_htt assert mock_httpx_client.send.call_count == retries +def test__execute_http_request__when_streaming__should_handle_non_200_response_code(mock_httpx_client: Mock) -> None: + error_details = "test_error" + request = Request(method=_METHOD, url=_URL) + response = httpx.Response(status_code=401, request=request, text=error_details) + mock_httpx_client.send.return_value = response + + client = AI21HTTPClient(client=mock_httpx_client, base_url=_URL, api_key=_API_KEY) + with pytest.raises(Unauthorized, match=error_details): + client.execute_http_request(method=_METHOD, stream=True) + + @pytest.mark.asyncio async def test__execute_async_http_request__when_retry_error_code_once__should_retry_and_succeed( mock_httpx_async_client: Mock, @@ -74,3 +85,17 @@ async def test__execute_async_http_request__when_retry_error__should_retry_and_s await client.execute_http_request(method=_METHOD) assert mock_httpx_async_client.send.call_count == retries + + +@pytest.mark.asyncio +async def test__execute_async_http_request__when_streaming__should_handle_non_200_response_code( + mock_httpx_async_client: Mock, +) -> None: + error_details = "test_error" + request = Request(method=_METHOD, url=_URL) + response = httpx.Response(status_code=401, request=request, text=error_details) + mock_httpx_async_client.send.return_value = response + + client = AsyncAI21HTTPClient(client=mock_httpx_async_client, base_url=_URL, api_key=_API_KEY) + with pytest.raises(Unauthorized, match=error_details): + await client.execute_http_request(method=_METHOD, stream=True)