diff --git a/src/yandex_cloud_ml_sdk/_client.py b/src/yandex_cloud_ml_sdk/_client.py index 8856848..383deb7 100644 --- a/src/yandex_cloud_ml_sdk/_client.py +++ b/src/yandex_cloud_ml_sdk/_client.py @@ -81,15 +81,6 @@ def __init__( async def _init_service_map(self, timeout: float): credentials = grpc.ssl_channel_credentials() metadata = await self._get_metadata(auth_required=False, timeout=timeout, retry_kind=RetryKind.SINGLE) - - # XXX: assistants is not at https://api.cloud.yandex.net/endpoints yet - # But when it will be, this dict keys will be just overridden - # TODO: delete - self._service_map = { - 'ai-assistants': 'assistant.api.cloud.yandex.net', - 'ai-files': 'assistant.api.cloud.yandex.net', - } - async with grpc.aio.secure_channel( self._endpoint, credentials, @@ -159,9 +150,7 @@ def _new_channel(self, endpoint: str) -> grpc.aio.Channel: options=self._get_options(), ) - async def _get_channel( - self, stub_class: type[_T], timeout: float, service_name: str | None = None - ) -> grpc.aio.Channel: + async def _get_channel(self, stub_class: type[_T], timeout: float) -> grpc.aio.Channel: if stub_class in self._channels: return self._channels[stub_class] @@ -169,26 +158,30 @@ async def _get_channel( if stub_class in self._channels: return self._channels[stub_class] - if service_name is None: - service_name = _service_for_ctor(stub_class) - + service_name: str = _service_for_ctor(stub_class) if not self._service_map: await self._init_service_map(timeout=timeout) if not (endpoint := self._service_map.get(service_name)): - raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}') + # NB: this fix will work if service_map will change ai-assistant to ai-assistants + # (and retrospectively if user will stuck with this version) + # and if _service_for_ctor will change ai-assistants to ai-assistant + if service_name in ('ai-assistant', 'ai-assistants'): + service_name = 'ai-assistant' if service_name == 'ai-assistants' else 'ai-assistants' + if not (endpoint := self._service_map.get(service_name)): + raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}') + else: + raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}') channel = self._channels[stub_class] = self._new_channel(endpoint) return channel @asynccontextmanager - async def get_service_stub( - self, stub_class: type[_T], timeout: float, service_name: str | None = None - ) -> AsyncIterator[_T]: + async def get_service_stub(self, stub_class: type[_T], timeout: float) -> AsyncIterator[_T]: # NB: right now get_service_stub is asynccontextmanager and it is unnecessary, # but in future if we will make some ChannelPool, it could be handy to know, # when "user" releases resource - channel = await self._get_channel(stub_class, timeout, service_name=service_name) + channel = await self._get_channel(stub_class, timeout) yield stub_class(channel) async def call_service_stream( diff --git a/src/yandex_cloud_ml_sdk/_search_indexes/domain.py b/src/yandex_cloud_ml_sdk/_search_indexes/domain.py index 3397a92..967bfd9 100644 --- a/src/yandex_cloud_ml_sdk/_search_indexes/domain.py +++ b/src/yandex_cloud_ml_sdk/_search_indexes/domain.py @@ -67,9 +67,7 @@ async def _create_deferred( text_search_index=text_search_index, ) - async with self._client.get_service_stub( - SearchIndexServiceStub, timeout=timeout, service_name='ai-assistants' - ) as stub: + async with self._client.get_service_stub(SearchIndexServiceStub, timeout=timeout) as stub: response = await self._client.call_service( stub.Create, request, @@ -80,8 +78,7 @@ async def _create_deferred( return self._operation_type( id=response.id, sdk=self._sdk, - result_type=self._impl, - service_name='ai-assistants', + result_type=self._impl ) async def _get( diff --git a/src/yandex_cloud_ml_sdk/_types/operation.py b/src/yandex_cloud_ml_sdk/_types/operation.py index 34ad78d..7a1f333 100644 --- a/src/yandex_cloud_ml_sdk/_types/operation.py +++ b/src/yandex_cloud_ml_sdk/_types/operation.py @@ -87,14 +87,11 @@ async def _wait( class BaseOperation(OperationInterface[ResultTypeT]): _last_known_status: OperationStatus | None - def __init__( - self, sdk: BaseSDK, id: str, result_type: type[ResultTypeT], service_name: str | None = None - ): # pylint: disable=redefined-builtin + def __init__(self, sdk: BaseSDK, id: str, result_type: type[ResultTypeT]): # pylint: disable=redefined-builtin self._id = id self._sdk = sdk self._result_type: type[BaseResult] = result_type self._last_known_status = None - self._service_name = service_name @property def id(self): @@ -106,11 +103,7 @@ def _client(self): async def _get_status(self, *, timeout: float = 60) -> OperationStatus: request = GetOperationRequest(operation_id=self.id) - async with self._client.get_service_stub( - OperationServiceStub, - timeout=timeout, - service_name=self._service_name, - ) as stub: + async with self._client.get_service_stub(OperationServiceStub, timeout=timeout) as stub: response = await self._client.call_service( stub.Get, request, @@ -153,11 +146,7 @@ async def _get_result(self, *, timeout: float = 60) -> ResultTypeT: async def _cancel(self, *, timeout: float = 60) -> OperationStatus: request = CancelOperationRequest(operation_id=self.id) - async with self._client.get_service_stub( - OperationServiceStub, - timeout=timeout, - service_name=self._service_name, - ) as stub: + async with self._client.get_service_stub(OperationServiceStub, timeout=timeout) as stub: response = await self._client.call_service( stub.Cancel, request, diff --git a/tests/conftest.py b/tests/conftest.py index e509969..5847bc1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -153,7 +153,10 @@ def fixture_async_sdk( interceptors=interceptors, auth=auth, retry_policy=retry_policy, - service_map={ # TMP + service_map={ + # NOT SO TMP after all + # to remove this, we need to regenerate all of assistant tests cassetes + # and maybe change etalons in tests, so it needs some effort 'ai-files': 'assistant.api.cloud.yandex.net', 'ai-assistants': 'assistant.api.cloud.yandex.net', }