Skip to content

Commit

Permalink
Add custom grpc exception with extra context
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar committed Dec 11, 2024
1 parent 152f4a0 commit a609a71
Show file tree
Hide file tree
Showing 12 changed files with 338 additions and 101 deletions.
20 changes: 19 additions & 1 deletion src/yandex_cloud_ml_sdk/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import sys
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Literal, Protocol, Sequence, TypeVar, cast
Expand All @@ -14,6 +15,7 @@
from yandex.cloud.endpoint.api_endpoint_service_pb2_grpc import ApiEndpointServiceStub

from ._auth import BaseAuth, get_auth_provider
from ._exceptions import AioRpcError
from ._retry import RETRY_KIND_METADATA_KEY, RetryKind, RetryPolicy
from ._utils.lock import LazyLock
from ._utils.proto import service_for_ctor
Expand Down Expand Up @@ -72,6 +74,7 @@ def __init__(
)

self._channels: dict[type[StubType], grpc.aio.Channel] = {}
self._endpoints: dict[type[StubType], str] = {}

self._auth_lock = LazyLock()
self._channels_lock = LazyLock()
Expand Down Expand Up @@ -105,6 +108,7 @@ async def _get_metadata(
) -> tuple[tuple[str, str], ...]:
metadata: tuple[tuple[str, str], ...] = (
(RETRY_KIND_METADATA_KEY, retry_kind.name),
('x-client-request-id', str(uuid.uuid4())),
)

if self._enable_server_data_logging is not None:
Expand Down Expand Up @@ -176,6 +180,7 @@ async def _get_channel(
else:
raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}')

self._endpoints[stub_class] = endpoint
channel = self._channels[stub_class] = self._new_channel(endpoint)
return channel

Expand All @@ -190,7 +195,20 @@ async def get_service_stub(
# but in future if we will make some ChannelPool, it could be handy to know,
# when "user" releases resource
channel = await self._get_channel(stub_class, timeout, service_name=service_name)
yield stub_class(channel)
try:
yield stub_class(channel)
except grpc.aio.AioRpcError as original:
# .with_traceback(...) from None allows to mimic
# original exception without increasing traceback with an
# extra info, like
# "During handling of the above exception, another exception occurred"
# or # "The above exception was the direct cause of the following exception"
raise AioRpcError.from_base_rpc_error(
original,
endpoint=self._endpoints[stub_class],
auth=self._auth_provider,
stub_class=stub_class,
).with_traceback(original.__traceback__) from None

async def call_service_stream(
self,
Expand Down
143 changes: 143 additions & 0 deletions src/yandex_cloud_ml_sdk/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from google.rpc.status_pb2 import Status as ProtoStatus # pylint: disable=no-name-in-module
from grpc import StatusCode
from grpc.aio import AioRpcError as BaseAioRpcError
from grpc.aio import Metadata

from ._auth import BaseAuth

if TYPE_CHECKING:
# pylint: disable=cyclic-import
from ._client import StubType
from ._datasets.validation import DatasetValidationResult


class YCloudMLError(Exception):
pass


class RunError(YCloudMLError):
def __init__(self, code: int, message: str, details: list[Any] | None, operation_id: str):
self.code = code
self.message = message
self.details = details or []
self.operation_id = operation_id

def __str__(self):
message = self.message or "<Empty message>"
message = f'Operation {self.operation_id} failed with message: {message} (code {self.code})'
message += '\n' + '\n'.join(repr(d) for d in self.details)
return message

@classmethod
def from_proro_status(cls, status: ProtoStatus, operation_id: str):
return cls(
code=status.code,
message=status.message,
details=list(status.details) if status.details else None,
operation_id=operation_id,
)


class AsyncOperationError(YCloudMLError):
pass


class WrongAsyncOperationStatusError(AsyncOperationError):
pass


class DatasetValidationError(AsyncOperationError):
def __init__(self, validation_result: DatasetValidationResult):
self._result = validation_result

errors_str = '\n'.join(str(error) for error in self.errors)
message = f"Dataset validation for dataset_id={self.dataset_id} failed with next errors:\n{errors_str}"
super().__init__(message)

@property
def errors(self):
return self._result.errors

@property
def dataset_id(self):
return self._result.dataset_id


class AioRpcError(BaseAioRpcError):
_initial_metadata: Metadata | None
_trailing_metadata: Metadata | None

def __init__(
self,
*args,
endpoint: str,
auth: BaseAuth | None,
stub_class: type[StubType],
**kwargs,
):
super().__init__(*args, **kwargs)
self._endpoint = endpoint
self._auth = auth
self._stub_class = stub_class

self._client_request_id: str

initial = self._initial_metadata
trailing = self._trailing_metadata

if (
initial is not None and not isinstance(initial, Metadata) or
trailing is not None and not isinstance(trailing, Metadata)
):
self._client_request_id = "grpc metadata was replaced with non-Metadata object"
else:
self._client_request_id = (
initial and initial.get('x-client-request-id') or
trailing and trailing.get('x-client-request-id') or
""
)

@classmethod
def from_base_rpc_error(
cls,
original: BaseAioRpcError,
endpoint: str,
auth: BaseAuth | None,
stub_class: type[StubType],
) -> AioRpcError:
return cls(
code=original.code(),
initial_metadata=original.initial_metadata(),
trailing_metadata=original.trailing_metadata(),
details=original.details(),
debug_error_string=original.debug_error_string(),
endpoint=endpoint,
auth=auth,
stub_class=stub_class,
).with_traceback(original.__traceback__)

def __str__(self):
parts = [
f"code = {self._code}",
f'details = "{self._details}"',
f'debug_error_string = "{self._debug_error_string}"',
f'endpoint = "{self._endpoint}"',
f'stub_class = {self._stub_class.__name__}'
]

if self._client_request_id:
parts.append(f'x-client-request-id = "{self._client_request_id}"')

if self._code == StatusCode.UNAUTHENTICATED:
auth = self._auth.__class__.__name__ if self._auth else None
parts.append(
f"auth_provider = {auth}"
)

body = '\n'.join(f'\t{part}' for part in parts)

return f"<{self.__class__.__name__} of RPC that terminated with:\n{body}\n>"
72 changes: 12 additions & 60 deletions src/yandex_cloud_ml_sdk/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING as _TYPE_CHECKING
from typing import Any as _Any

from google.rpc.status_pb2 import Status as _ProtoStatus # pylint: disable=no-name-in-module

if _TYPE_CHECKING:
# pylint: disable=cyclic-import
from yandex_cloud_ml_sdk._datasets.validation import DatasetValidationResult as _DatasetValidationResult


class YCloudMLError(Exception):
pass


class RunError(YCloudMLError):
def __init__(self, code: int, message: str, details: list[_Any] | None, operation_id: str):
self.code = code
self.message = message
self.details = details or []
self.operation_id = operation_id

def __str__(self):
message = self.message or "<Empty message>"
message = f'Operation {self.operation_id} failed with message: {message} (code {self.code})'
message += '\n' + '\n'.join(repr(d) for d in self.details)
return message

@classmethod
def from_proro_status(cls, status: _ProtoStatus, operation_id: str):
return cls(
code=status.code,
message=status.message,
details=list(status.details) if status.details else None,
operation_id=operation_id,
)


class AsyncOperationError(YCloudMLError):
pass


class WrongAsyncOperationStatusError(AsyncOperationError):
pass


class DatasetValidationError(AsyncOperationError):
def __init__(self, validation_result: _DatasetValidationResult):
self._result = validation_result

errors_str = '\n'.join(str(error) for error in self.errors)
message = f"Dataset validation for dataset_id={self.dataset_id} failed with next errors:\n{errors_str}"
super().__init__(message)

@property
def errors(self):
return self._result.errors

@property
def dataset_id(self):
return self._result.dataset_id
from ._exceptions import (
AioRpcError, AsyncOperationError, DatasetValidationError, RunError, WrongAsyncOperationStatusError, YCloudMLError
)

__all__ = [
'AioRpcError',
'AsyncOperationError',
'DatasetValidationError',
'RunError',
'WrongAsyncOperationStatusError',
'YCloudMLError',
]
9 changes: 9 additions & 0 deletions tests/auth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,12 @@ def terminate(self):
pass

return MockProcess

@pytest.fixture(name="get_auth_meta")
def fixture_get_auth_meta():
def getter(metadata):
for key, value in metadata:
if key == 'authorization':
return value
return None
return getter
7 changes: 2 additions & 5 deletions tests/auth/test_api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@ def fixture_auth(api_key):
return APIKeyAuth(api_key)


async def test_auth(async_sdk, api_key):
async def test_auth(async_sdk, api_key, get_auth_meta):
metadata = await async_sdk._client._get_metadata(
auth_required=True,
timeout=1
)
assert metadata == (
('yc-ml-sdk-retry', 'NONE'),
('authorization', f'Api-Key {api_key}'),
)
assert get_auth_meta(metadata) == f'Api-Key {api_key}'


async def test_applicable_from_env(api_key, monkeypatch):
Expand Down
12 changes: 3 additions & 9 deletions tests/auth/test_env_iam_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,19 @@ def fixture_auth(iam_token, monkeypatch):
return EnvIAMTokenAuth()


async def test_auth(async_sdk, iam_token, monkeypatch):
async def test_auth(async_sdk, iam_token, monkeypatch, get_auth_meta):
metadata = await async_sdk._client._get_metadata(
auth_required=True,
timeout=1
)
assert metadata == (
('yc-ml-sdk-retry', 'NONE'),
('authorization', f'Bearer {iam_token}'),
)
assert get_auth_meta(metadata) == f'Bearer {iam_token}'

monkeypatch.setenv(EnvIAMTokenAuth.default_env_var, 'foo')
metadata = await async_sdk._client._get_metadata(
auth_required=True,
timeout=1
)
assert metadata == (
('yc-ml-sdk-retry', 'NONE'),
('authorization', 'Bearer foo'),
)
assert get_auth_meta(metadata) == 'Bearer foo'


async def test_applicable_from_env(iam_token, monkeypatch):
Expand Down
7 changes: 2 additions & 5 deletions tests/auth/test_iam_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@ def fixture_auth(iam_token):
return IAMTokenAuth(iam_token)


async def test_auth(async_sdk, iam_token):
async def test_auth(async_sdk, iam_token, get_auth_meta):
metadata = await async_sdk._client._get_metadata(
auth_required=True,
timeout=1
)
assert metadata == (
('yc-ml-sdk-retry', 'NONE'),
('authorization', f'Bearer {iam_token}'),
)
assert get_auth_meta(metadata) == f'Bearer {iam_token}'


async def test_applicable_from_env(iam_token, monkeypatch):
Expand Down
7 changes: 2 additions & 5 deletions tests/auth/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def fixture_auth():
return MetadataAuth()


async def test_auth(async_sdk, iam_token, mock_client):
async def test_auth(async_sdk, iam_token, mock_client, get_auth_meta):
response = httpx.Response(
status_code=200,
text=f'{{"access_token":"{iam_token}","expires_in":42055,"token_type":"Bearer"}}',
Expand All @@ -31,10 +31,7 @@ async def test_auth(async_sdk, iam_token, mock_client):
mock_client.get.return_value = response

metadata = await async_sdk._client._get_metadata(auth_required=True, timeout=1)
assert metadata == (
('yc-ml-sdk-retry', 'NONE'),
("authorization", f"Bearer {iam_token}"),
)
assert get_auth_meta(metadata) == f"Bearer {iam_token}"


async def test_reissue(async_sdk, auth, monkeypatch, mock_client):
Expand Down
4 changes: 2 additions & 2 deletions tests/auth/test_no_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ def fixture_auth():
return NoAuth()


async def test_auth(async_sdk):
async def test_auth(async_sdk, get_auth_meta):
metadata = await async_sdk._client._get_metadata(
auth_required=True,
timeout=1
)
assert metadata == (('yc-ml-sdk-retry', 'NONE'),)
assert get_auth_meta(metadata) is None


async def test_applicable_from_env():
Expand Down
Loading

0 comments on commit a609a71

Please sign in to comment.