Skip to content

Commit

Permalink
Add grpc_credentials sdk option
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar committed Nov 15, 2024
1 parent 14ffe52 commit 35ba71b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/yandex_cloud_ml_sdk/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
yc_profile: str | None,
retry_policy: RetryPolicy,
enable_server_data_logging: bool | None,
credentials: grpc.ChannelCredentials | None,
):
self._endpoint = endpoint
self._auth = auth
Expand All @@ -77,6 +78,7 @@ def __init__(

self._user_agent = _get_user_agent()
self._enable_server_data_logging = enable_server_data_logging
self._credentials = credentials

async def _init_service_map(self, timeout: float):
credentials = grpc.ssl_channel_credentials()
Expand Down Expand Up @@ -142,7 +144,7 @@ def _get_options(self) -> tuple[tuple[str, str], ...]:
)

def _new_channel(self, endpoint: str) -> grpc.aio.Channel:
credentials = grpc.ssl_channel_credentials()
credentials = self._credentials or grpc.ssl_channel_credentials()
return grpc.aio.secure_channel(
endpoint,
credentials,
Expand Down
4 changes: 3 additions & 1 deletion src/yandex_cloud_ml_sdk/_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional, Sequence

from get_annotations import get_annotations
from grpc import aio
from grpc import ChannelCredentials, aio

from ._assistants.domain import Assistants, AsyncAssistants, BaseAssistants
from ._auth import BaseAuth
Expand Down Expand Up @@ -46,6 +46,7 @@ def __init__(
service_map: UndefinedOr[dict[str, str]] = UNDEFINED,
interceptors: UndefinedOr[Sequence[aio.ClientInterceptor]] = UNDEFINED,
enable_server_data_logging: UndefinedOr[bool] = UNDEFINED,
grpc_credentials: UndefinedOr[ChannelCredentials] = UNDEFINED,
):
"""
Construct a new asynchronous sdk instance.
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
interceptors=get_defined_value(interceptors, None),
yc_profile=get_defined_value(yc_profile, None),
enable_server_data_logging=get_defined_value(enable_server_data_logging, None),
credentials=get_defined_value(grpc_credentials, None),
)
self._folder_id = folder_id

Expand Down
1 change: 1 addition & 0 deletions src/yandex_cloud_ml_sdk/_testing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
yc_profile=None,
retry_policy=retry_policy,
enable_server_data_logging=None,
credentials=None,
)
self.port = port
self._sdk = sdk
Expand Down
22 changes: 22 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
from multiprocessing.pool import ThreadPool

import grpc
import pytest
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import Token
from yandex.cloud.ai.foundation_models.v1.text_generation.text_generation_service_pb2 import (
Expand Down Expand Up @@ -159,6 +160,7 @@ async def test_x_data_logging(interceptors, retry_policy):
retry_policy=retry_policy,
interceptors=interceptors,
enable_server_data_logging=None,
credentials=None,
)

assert await client._get_metadata(auth_required=False, timeout=0) == base_result
Expand All @@ -171,6 +173,7 @@ async def test_x_data_logging(interceptors, retry_policy):
retry_policy=retry_policy,
interceptors=interceptors,
enable_server_data_logging=True,
credentials=None,
)

assert await client._get_metadata(auth_required=False, timeout=0) == base_result + (
Expand All @@ -185,8 +188,27 @@ async def test_x_data_logging(interceptors, retry_policy):
retry_policy=retry_policy,
interceptors=interceptors,
enable_server_data_logging=False,
credentials=None,
)

assert await client._get_metadata(auth_required=False, timeout=0) == base_result + (
('x-data-logging-enabled', "false"),
)


@pytest.mark.asyncio
async def test_channel_credentials(folder_id):
sdk = AsyncYCloudML(folder_id=folder_id)
assert sdk._client._credentials is None
sdk._client._new_channel('foo')

creds = grpc.ssl_channel_credentials()
sdk = AsyncYCloudML(folder_id=folder_id, grpc_credentials=creds)
assert sdk._client._credentials is creds
sdk._client._new_channel('foo')

# this test checks if passed grpc_credentials is really used in
# channel creation
sdk = AsyncYCloudML(folder_id=folder_id, grpc_credentials=1)
with pytest.raises(AttributeError, match="'int' object has no attribute '_credentials'"):
sdk._client._new_channel('foo')

0 comments on commit 35ba71b

Please sign in to comment.