diff --git a/src/yandex_cloud_ml_sdk/_client.py b/src/yandex_cloud_ml_sdk/_client.py index f535d06..ad85de1 100644 --- a/src/yandex_cloud_ml_sdk/_client.py +++ b/src/yandex_cloud_ml_sdk/_client.py @@ -2,6 +2,7 @@ from __future__ import annotations import sys +import uuid from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Literal, Protocol, Sequence, TypeVar, cast @@ -14,6 +15,7 @@ from yandex.cloud.endpoint.api_endpoint_service_pb2_grpc import ApiEndpointServiceStub from ._auth import BaseAuth, get_auth_provider +from ._exceptions import AioRpcError from ._retry import RETRY_KIND_METADATA_KEY, RetryKind, RetryPolicy from ._utils.lock import LazyLock from ._utils.proto import service_for_ctor @@ -72,6 +74,7 @@ def __init__( ) self._channels: dict[type[StubType], grpc.aio.Channel] = {} + self._endpoints: dict[type[StubType], str] = {} self._auth_lock = LazyLock() self._channels_lock = LazyLock() @@ -105,6 +108,7 @@ async def _get_metadata( ) -> tuple[tuple[str, str], ...]: metadata: tuple[tuple[str, str], ...] = ( (RETRY_KIND_METADATA_KEY, retry_kind.name), + ('x-client-request-id', str(uuid.uuid4())), ) if self._enable_server_data_logging is not None: @@ -176,6 +180,7 @@ async def _get_channel( else: raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}') + self._endpoints[stub_class] = endpoint channel = self._channels[stub_class] = self._new_channel(endpoint) return channel @@ -190,7 +195,20 @@ async def get_service_stub( # but in future if we will make some ChannelPool, it could be handy to know, # when "user" releases resource channel = await self._get_channel(stub_class, timeout, service_name=service_name) - yield stub_class(channel) + try: + yield stub_class(channel) + except grpc.aio.AioRpcError as original: + # .with_traceback(...) from None allows to mimic + # original exception without increasing traceback with an + # extra info, like + # "During handling of the above exception, another exception occurred" + # or # "The above exception was the direct cause of the following exception" + raise AioRpcError.from_base_rpc_error( + original, + endpoint=self._endpoints[stub_class], + auth=self._auth_provider, + stub_class=stub_class, + ).with_traceback(original.__traceback__) from None async def call_service_stream( self, diff --git a/src/yandex_cloud_ml_sdk/_exceptions.py b/src/yandex_cloud_ml_sdk/_exceptions.py new file mode 100644 index 0000000..3db3a07 --- /dev/null +++ b/src/yandex_cloud_ml_sdk/_exceptions.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from google.rpc.status_pb2 import Status as ProtoStatus # pylint: disable=no-name-in-module +from grpc import StatusCode +from grpc.aio import AioRpcError as BaseAioRpcError +from grpc.aio import Metadata + +from ._auth import BaseAuth + +if TYPE_CHECKING: + # pylint: disable=cyclic-import + from ._client import StubType + from ._datasets.validation import DatasetValidationResult + + +class YCloudMLError(Exception): + pass + + +class RunError(YCloudMLError): + def __init__(self, code: int, message: str, details: list[Any] | None, operation_id: str): + self.code = code + self.message = message + self.details = details or [] + self.operation_id = operation_id + + def __str__(self): + message = self.message or "" + message = f'Operation {self.operation_id} failed with message: {message} (code {self.code})' + message += '\n' + '\n'.join(repr(d) for d in self.details) + return message + + @classmethod + def from_proro_status(cls, status: ProtoStatus, operation_id: str): + return cls( + code=status.code, + message=status.message, + details=list(status.details) if status.details else None, + operation_id=operation_id, + ) + + +class AsyncOperationError(YCloudMLError): + pass + + +class WrongAsyncOperationStatusError(AsyncOperationError): + pass + + +class DatasetValidationError(AsyncOperationError): + def __init__(self, validation_result: DatasetValidationResult): + self._result = validation_result + + errors_str = '\n'.join(str(error) for error in self.errors) + message = f"Dataset validation for dataset_id={self.dataset_id} failed with next errors:\n{errors_str}" + super().__init__(message) + + @property + def errors(self): + return self._result.errors + + @property + def dataset_id(self): + return self._result.dataset_id + + +class AioRpcError(BaseAioRpcError): + _initial_metadata: Metadata | None + _trailing_metadata: Metadata | None + + def __init__( + self, + *args, + endpoint: str, + auth: BaseAuth | None, + stub_class: type[StubType], + **kwargs, + ): + super().__init__(*args, **kwargs) + self._endpoint = endpoint + self._auth = auth + self._stub_class = stub_class + + self._client_request_id: str + + initial = self._initial_metadata + trailing = self._trailing_metadata + + if ( + initial is not None and not isinstance(initial, Metadata) or + trailing is not None and not isinstance(trailing, Metadata) + ): + self._client_request_id = "grpc metadata was replaced with non-Metadata object" + else: + self._client_request_id = ( + initial and initial.get('x-client-request-id') or + trailing and trailing.get('x-client-request-id') or + "" + ) + + @classmethod + def from_base_rpc_error( + cls, + original: BaseAioRpcError, + endpoint: str, + auth: BaseAuth | None, + stub_class: type[StubType], + ) -> AioRpcError: + return cls( + code=original.code(), + initial_metadata=original.initial_metadata(), + trailing_metadata=original.trailing_metadata(), + details=original.details(), + debug_error_string=original.debug_error_string(), + endpoint=endpoint, + auth=auth, + stub_class=stub_class, + ).with_traceback(original.__traceback__) + + def __str__(self): + parts = [ + f"code = {self._code}", + f'details = "{self._details}"', + f'debug_error_string = "{self._debug_error_string}"', + f'endpoint = "{self._endpoint}"', + f'stub_class = {self._stub_class.__name__}' + ] + + if self._client_request_id: + parts.append(f'x-client-request-id = "{self._client_request_id}"') + + if self._code == StatusCode.UNAUTHENTICATED: + auth = self._auth.__class__.__name__ if self._auth else None + parts.append( + f"auth_provider = {auth}" + ) + + body = '\n'.join(f'\t{part}' for part in parts) + + return f"<{self.__class__.__name__} of RPC that terminated with:\n{body}\n>" diff --git a/src/yandex_cloud_ml_sdk/exceptions.py b/src/yandex_cloud_ml_sdk/exceptions.py index 8fd37f1..e92fa33 100644 --- a/src/yandex_cloud_ml_sdk/exceptions.py +++ b/src/yandex_cloud_ml_sdk/exceptions.py @@ -1,62 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING as _TYPE_CHECKING -from typing import Any as _Any - -from google.rpc.status_pb2 import Status as _ProtoStatus # pylint: disable=no-name-in-module - -if _TYPE_CHECKING: - # pylint: disable=cyclic-import - from yandex_cloud_ml_sdk._datasets.validation import DatasetValidationResult as _DatasetValidationResult - - -class YCloudMLError(Exception): - pass - - -class RunError(YCloudMLError): - def __init__(self, code: int, message: str, details: list[_Any] | None, operation_id: str): - self.code = code - self.message = message - self.details = details or [] - self.operation_id = operation_id - - def __str__(self): - message = self.message or "" - message = f'Operation {self.operation_id} failed with message: {message} (code {self.code})' - message += '\n' + '\n'.join(repr(d) for d in self.details) - return message - - @classmethod - def from_proro_status(cls, status: _ProtoStatus, operation_id: str): - return cls( - code=status.code, - message=status.message, - details=list(status.details) if status.details else None, - operation_id=operation_id, - ) - - -class AsyncOperationError(YCloudMLError): - pass - - -class WrongAsyncOperationStatusError(AsyncOperationError): - pass - - -class DatasetValidationError(AsyncOperationError): - def __init__(self, validation_result: _DatasetValidationResult): - self._result = validation_result - - errors_str = '\n'.join(str(error) for error in self.errors) - message = f"Dataset validation for dataset_id={self.dataset_id} failed with next errors:\n{errors_str}" - super().__init__(message) - - @property - def errors(self): - return self._result.errors - - @property - def dataset_id(self): - return self._result.dataset_id +from ._exceptions import ( + AioRpcError, AsyncOperationError, DatasetValidationError, RunError, WrongAsyncOperationStatusError, YCloudMLError +) + +__all__ = [ + 'AioRpcError', + 'AsyncOperationError', + 'DatasetValidationError', + 'RunError', + 'WrongAsyncOperationStatusError', + 'YCloudMLError', +] diff --git a/tests/auth/conftest.py b/tests/auth/conftest.py index 7d2c90f..b303c11 100644 --- a/tests/auth/conftest.py +++ b/tests/auth/conftest.py @@ -40,3 +40,12 @@ def terminate(self): pass return MockProcess + +@pytest.fixture(name="get_auth_meta") +def fixture_get_auth_meta(): + def getter(metadata): + for key, value in metadata: + if key == 'authorization': + return value + return None + return getter diff --git a/tests/auth/test_api_key.py b/tests/auth/test_api_key.py index 2c3bd27..ab74622 100644 --- a/tests/auth/test_api_key.py +++ b/tests/auth/test_api_key.py @@ -18,15 +18,12 @@ def fixture_auth(api_key): return APIKeyAuth(api_key) -async def test_auth(async_sdk, api_key): +async def test_auth(async_sdk, api_key, get_auth_meta): metadata = await async_sdk._client._get_metadata( auth_required=True, timeout=1 ) - assert metadata == ( - ('yc-ml-sdk-retry', 'NONE'), - ('authorization', f'Api-Key {api_key}'), - ) + assert get_auth_meta(metadata) == f'Api-Key {api_key}' async def test_applicable_from_env(api_key, monkeypatch): diff --git a/tests/auth/test_env_iam_token.py b/tests/auth/test_env_iam_token.py index 71ce024..d1a4114 100644 --- a/tests/auth/test_env_iam_token.py +++ b/tests/auth/test_env_iam_token.py @@ -19,25 +19,19 @@ def fixture_auth(iam_token, monkeypatch): return EnvIAMTokenAuth() -async def test_auth(async_sdk, iam_token, monkeypatch): +async def test_auth(async_sdk, iam_token, monkeypatch, get_auth_meta): metadata = await async_sdk._client._get_metadata( auth_required=True, timeout=1 ) - assert metadata == ( - ('yc-ml-sdk-retry', 'NONE'), - ('authorization', f'Bearer {iam_token}'), - ) + assert get_auth_meta(metadata) == f'Bearer {iam_token}' monkeypatch.setenv(EnvIAMTokenAuth.default_env_var, 'foo') metadata = await async_sdk._client._get_metadata( auth_required=True, timeout=1 ) - assert metadata == ( - ('yc-ml-sdk-retry', 'NONE'), - ('authorization', 'Bearer foo'), - ) + assert get_auth_meta(metadata) == 'Bearer foo' async def test_applicable_from_env(iam_token, monkeypatch): diff --git a/tests/auth/test_iam_token.py b/tests/auth/test_iam_token.py index cfb982b..7f45a6a 100644 --- a/tests/auth/test_iam_token.py +++ b/tests/auth/test_iam_token.py @@ -18,15 +18,12 @@ def fixture_auth(iam_token): return IAMTokenAuth(iam_token) -async def test_auth(async_sdk, iam_token): +async def test_auth(async_sdk, iam_token, get_auth_meta): metadata = await async_sdk._client._get_metadata( auth_required=True, timeout=1 ) - assert metadata == ( - ('yc-ml-sdk-retry', 'NONE'), - ('authorization', f'Bearer {iam_token}'), - ) + assert get_auth_meta(metadata) == f'Bearer {iam_token}' async def test_applicable_from_env(iam_token, monkeypatch): diff --git a/tests/auth/test_metadata.py b/tests/auth/test_metadata.py index 17331ad..fff54c4 100644 --- a/tests/auth/test_metadata.py +++ b/tests/auth/test_metadata.py @@ -22,7 +22,7 @@ def fixture_auth(): return MetadataAuth() -async def test_auth(async_sdk, iam_token, mock_client): +async def test_auth(async_sdk, iam_token, mock_client, get_auth_meta): response = httpx.Response( status_code=200, text=f'{{"access_token":"{iam_token}","expires_in":42055,"token_type":"Bearer"}}', @@ -31,10 +31,7 @@ async def test_auth(async_sdk, iam_token, mock_client): mock_client.get.return_value = response metadata = await async_sdk._client._get_metadata(auth_required=True, timeout=1) - assert metadata == ( - ('yc-ml-sdk-retry', 'NONE'), - ("authorization", f"Bearer {iam_token}"), - ) + assert get_auth_meta(metadata) == f"Bearer {iam_token}" async def test_reissue(async_sdk, auth, monkeypatch, mock_client): diff --git a/tests/auth/test_no_auth.py b/tests/auth/test_no_auth.py index a41845d..fa9a4c6 100644 --- a/tests/auth/test_no_auth.py +++ b/tests/auth/test_no_auth.py @@ -13,12 +13,12 @@ def fixture_auth(): return NoAuth() -async def test_auth(async_sdk): +async def test_auth(async_sdk, get_auth_meta): metadata = await async_sdk._client._get_metadata( auth_required=True, timeout=1 ) - assert metadata == (('yc-ml-sdk-retry', 'NONE'),) + assert get_auth_meta(metadata) is None async def test_applicable_from_env(): diff --git a/tests/auth/test_oauth_token.py b/tests/auth/test_oauth_token.py index 3984b7f..589d7b6 100644 --- a/tests/auth/test_oauth_token.py +++ b/tests/auth/test_oauth_token.py @@ -42,14 +42,11 @@ def Create(self, request, context): @pytest.mark.filterwarnings("ignore:.*OAuth:UserWarning") -async def test_auth(async_sdk, auth): +async def test_auth(async_sdk, auth, get_auth_meta): metadata = await async_sdk._client._get_metadata(auth_required=True, timeout=1) assert auth._issue_time is not None - assert metadata == ( - ('yc-ml-sdk-retry', 'NONE'), - ("authorization", "Bearer "), - ) + assert get_auth_meta(metadata) == "Bearer " @pytest.mark.filterwarnings(r"ignore:.*OAuth:UserWarning") diff --git a/tests/auth/test_yc_cli.py b/tests/auth/test_yc_cli.py index 86753a3..c92e89d 100644 --- a/tests/auth/test_yc_cli.py +++ b/tests/auth/test_yc_cli.py @@ -21,16 +21,13 @@ def fixture_auth(): return YandexCloudCLIAuth() -async def test_auth(async_sdk, iam_token, monkeypatch, process_maker): +async def test_auth(async_sdk, iam_token, monkeypatch, process_maker, get_auth_meta): process = process_maker(stdout=b"Hello\n" + iam_token.encode("utf-8"), stderr=b"") mock_create_subprocess_exec = AsyncMock(return_value=process) monkeypatch.setattr("asyncio.create_subprocess_exec", mock_create_subprocess_exec) metadata = await async_sdk._client._get_metadata(auth_required=True, timeout=1) - assert metadata == ( - ('yc-ml-sdk-retry', 'NONE'), - ("authorization", f"Bearer {iam_token}"), - ) + assert get_auth_meta(metadata) == f"Bearer {iam_token}" async def test_reissue(async_sdk, auth, monkeypatch, process_maker): diff --git a/tests/test_client.py b/tests/test_client.py index 9f7b42c..37c6f13 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,6 +6,7 @@ from multiprocessing.pool import ThreadPool import grpc +import grpc.aio import pytest from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import Token from yandex.cloud.ai.foundation_models.v1.text_generation.text_generation_service_pb2 import ( @@ -18,9 +19,11 @@ from yandex.cloud.endpoint.api_endpoint_service_pb2 import ListApiEndpointsRequest, ListApiEndpointsResponse from yandex.cloud.endpoint.api_endpoint_service_pb2_grpc import ApiEndpointServiceStub +import yandex_cloud_ml_sdk._client from yandex_cloud_ml_sdk import AsyncYCloudML from yandex_cloud_ml_sdk._client import AsyncCloudClient, _get_user_agent, httpx_client from yandex_cloud_ml_sdk.auth import NoAuth +from yandex_cloud_ml_sdk.exceptions import AioRpcError @pytest.fixture(name='servicers') @@ -151,7 +154,18 @@ async def test_httpx_client(): # pylint: disable=protected-access @pytest.mark.asyncio async def test_x_data_logging(interceptors, retry_policy): - base_result = (('yc-ml-sdk-retry', 'NONE'),) + def check_result(metadata, extra=None): + retry = ('yc-ml-sdk-retry', 'NONE') + for key, value in metadata: + if retry == (key, value): + continue + if key == 'x-client-request-id': + continue + if (key, value) == extra: + continue + + assert not (key, value) + client = AsyncCloudClient( endpoint="foo", auth=NoAuth(), @@ -163,7 +177,7 @@ async def test_x_data_logging(interceptors, retry_policy): credentials=None, ) - assert await client._get_metadata(auth_required=False, timeout=0) == base_result + check_result(await client._get_metadata(auth_required=False, timeout=0)) client = AsyncCloudClient( endpoint="foo", @@ -176,7 +190,8 @@ async def test_x_data_logging(interceptors, retry_policy): credentials=None, ) - assert await client._get_metadata(auth_required=False, timeout=0) == base_result + ( + check_result( + await client._get_metadata(auth_required=False, timeout=0), ('x-data-logging-enabled', "true"), ) @@ -191,7 +206,8 @@ async def test_x_data_logging(interceptors, retry_policy): credentials=None, ) - assert await client._get_metadata(auth_required=False, timeout=0) == base_result + ( + check_result( + await client._get_metadata(auth_required=False, timeout=0), ('x-data-logging-enabled', "false"), ) @@ -212,3 +228,123 @@ async def test_channel_credentials(folder_id): sdk = AsyncYCloudML(folder_id=folder_id, grpc_credentials=1) with pytest.raises(AttributeError, match="'int' object has no attribute '_credentials'"): sdk._client._new_channel('foo') + + +@pytest.mark.asyncio +async def test_grpc_base_exception(async_sdk, monkeypatch, test_server): + result = await async_sdk.models.completions('foo').tokenize("bar") + assert result + + def raise_call_service(*args, **kwargs): + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.INTERNAL, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + details="some details", + debug_error_string="some debug" + ) + + monkeypatch.setattr(yandex_cloud_ml_sdk._client.AsyncCloudClient, 'call_service', raise_call_service) + + with pytest.raises(AioRpcError) as exc_info: + await async_sdk.models.completions('foo').tokenize("bar") + exc = exc_info.value + exc_repr = str(exc) + + assert '"some details"' in exc_repr + assert '"some debug"' in exc_repr + assert f'\tendpoint = "localhost:{test_server.port}"\n' in exc_repr + assert '\tstub_class = TokenizerServiceStub\n' in exc_repr + assert '\tauth_provider' not in exc_repr + assert '\tx-client-request-id' not in exc_repr + + +@pytest.mark.asyncio +async def test_grpc_unauth_exception(async_sdk, monkeypatch, auth): + result = await async_sdk.models.completions('foo').tokenize("bar") + assert result + + def raise_call_service_unauth(*args, **kwargs): + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.UNAUTHENTICATED, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + details="some details", + debug_error_string="some debug" + ) + + monkeypatch.setattr(yandex_cloud_ml_sdk._client.AsyncCloudClient, 'call_service', raise_call_service_unauth) + + with pytest.raises(AioRpcError) as exc_info: + await async_sdk.models.completions('foo').tokenize("bar") + exc = exc_info.value + exc_repr = str(exc) + + assert f'\tauth_provider = {auth.__class__.__name__}\n' in exc_repr + assert '\tx-client-request-id' not in exc_repr + + +@pytest.mark.asyncio +async def test_grpc_request_id_in_initial_metadata_exception(async_sdk, monkeypatch): + def raise_call_service_initial(*args, **kwargs): + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.INTERNAL, + initial_metadata=grpc.aio.Metadata(('x-client-request-id', 'INITIAL')), + trailing_metadata=grpc.aio.Metadata(), + details="some details", + debug_error_string="some debug" + ) + + monkeypatch.setattr(yandex_cloud_ml_sdk._client.AsyncCloudClient, 'call_service', raise_call_service_initial) + + with pytest.raises(AioRpcError) as exc_info: + await async_sdk.models.completions('foo').tokenize("bar") + exc = exc_info.value + exc_repr = str(exc) + + assert '\tauth_provider' not in exc_repr + assert '\tx-client-request-id = "INITIAL"\n' in exc_repr + + +@pytest.mark.asyncio +async def test_grpc_request_id_in_trailing_metadata_exception(async_sdk, monkeypatch): + def raise_call_service_trailing(*args, **kwargs): + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.INTERNAL, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(('x-client-request-id', 'TRAILING')), + details="some details", + debug_error_string="some debug" + ) + + monkeypatch.setattr(yandex_cloud_ml_sdk._client.AsyncCloudClient, 'call_service', raise_call_service_trailing) + + with pytest.raises(AioRpcError) as exc_info: + await async_sdk.models.completions('foo').tokenize("bar") + exc = exc_info.value + exc_repr = str(exc) + + assert '\tauth_provider' not in exc_repr + assert '\tx-client-request-id = "TRAILING"\n' in exc_repr + + +@pytest.mark.asyncio +async def test_grpc_request_id_wrong_metadata_exception(async_sdk, monkeypatch): + def raise_call_service_wrong(*args, **kwargs): + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.INTERNAL, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=(), + details="some details", + debug_error_string="some debug" + ) + + monkeypatch.setattr(yandex_cloud_ml_sdk._client.AsyncCloudClient, 'call_service', raise_call_service_wrong) + + with pytest.raises(AioRpcError) as exc_info: + await async_sdk.models.completions('foo').tokenize("bar") + exc = exc_info.value + exc_repr = str(exc) + + assert '\tauth_provider' not in exc_repr + assert '\tx-client-request-id = "grpc metadata was replaced with non-Metadata object"\n' in exc_repr