Skip to content

Commit

Permalink
Revert endpoints fix and bring new one(#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar authored Nov 13, 2024
1 parent e0fbff3 commit a98bd6b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 40 deletions.
33 changes: 13 additions & 20 deletions src/yandex_cloud_ml_sdk/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,6 @@ def __init__(
async def _init_service_map(self, timeout: float):
credentials = grpc.ssl_channel_credentials()
metadata = await self._get_metadata(auth_required=False, timeout=timeout, retry_kind=RetryKind.SINGLE)

# XXX: assistants is not at https://api.cloud.yandex.net/endpoints yet
# But when it will be, this dict keys will be just overridden
# TODO: delete
self._service_map = {
'ai-assistants': 'assistant.api.cloud.yandex.net',
'ai-files': 'assistant.api.cloud.yandex.net',
}

async with grpc.aio.secure_channel(
self._endpoint,
credentials,
Expand Down Expand Up @@ -159,36 +150,38 @@ def _new_channel(self, endpoint: str) -> grpc.aio.Channel:
options=self._get_options(),
)

async def _get_channel(
self, stub_class: type[_T], timeout: float, service_name: str | None = None
) -> grpc.aio.Channel:
async def _get_channel(self, stub_class: type[_T], timeout: float) -> grpc.aio.Channel:
if stub_class in self._channels:
return self._channels[stub_class]

async with self._channels_lock():
if stub_class in self._channels:
return self._channels[stub_class]

if service_name is None:
service_name = _service_for_ctor(stub_class)

service_name: str = _service_for_ctor(stub_class)
if not self._service_map:
await self._init_service_map(timeout=timeout)

if not (endpoint := self._service_map.get(service_name)):
raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}')
# NB: this fix will work if service_map will change ai-assistant to ai-assistants
# (and retrospectively if user will stuck with this version)
# and if _service_for_ctor will change ai-assistants to ai-assistant
if service_name in ('ai-assistant', 'ai-assistants'):
service_name = 'ai-assistant' if service_name == 'ai-assistants' else 'ai-assistants'
if not (endpoint := self._service_map.get(service_name)):
raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}')
else:
raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}')

channel = self._channels[stub_class] = self._new_channel(endpoint)
return channel

@asynccontextmanager
async def get_service_stub(
self, stub_class: type[_T], timeout: float, service_name: str | None = None
) -> AsyncIterator[_T]:
async def get_service_stub(self, stub_class: type[_T], timeout: float) -> AsyncIterator[_T]:
# NB: right now get_service_stub is asynccontextmanager and it is unnecessary,
# but in future if we will make some ChannelPool, it could be handy to know,
# when "user" releases resource
channel = await self._get_channel(stub_class, timeout, service_name=service_name)
channel = await self._get_channel(stub_class, timeout)
yield stub_class(channel)

async def call_service_stream(
Expand Down
7 changes: 2 additions & 5 deletions src/yandex_cloud_ml_sdk/_search_indexes/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ async def _create_deferred(
text_search_index=text_search_index,
)

async with self._client.get_service_stub(
SearchIndexServiceStub, timeout=timeout, service_name='ai-assistants'
) as stub:
async with self._client.get_service_stub(SearchIndexServiceStub, timeout=timeout) as stub:
response = await self._client.call_service(
stub.Create,
request,
Expand All @@ -80,8 +78,7 @@ async def _create_deferred(
return self._operation_type(
id=response.id,
sdk=self._sdk,
result_type=self._impl,
service_name='ai-assistants',
result_type=self._impl
)

async def _get(
Expand Down
17 changes: 3 additions & 14 deletions src/yandex_cloud_ml_sdk/_types/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,11 @@ async def _wait(
class BaseOperation(OperationInterface[ResultTypeT]):
_last_known_status: OperationStatus | None

def __init__(
self, sdk: BaseSDK, id: str, result_type: type[ResultTypeT], service_name: str | None = None
): # pylint: disable=redefined-builtin
def __init__(self, sdk: BaseSDK, id: str, result_type: type[ResultTypeT]): # pylint: disable=redefined-builtin
self._id = id
self._sdk = sdk
self._result_type: type[BaseResult] = result_type
self._last_known_status = None
self._service_name = service_name

@property
def id(self):
Expand All @@ -106,11 +103,7 @@ def _client(self):

async def _get_status(self, *, timeout: float = 60) -> OperationStatus:
request = GetOperationRequest(operation_id=self.id)
async with self._client.get_service_stub(
OperationServiceStub,
timeout=timeout,
service_name=self._service_name,
) as stub:
async with self._client.get_service_stub(OperationServiceStub, timeout=timeout) as stub:
response = await self._client.call_service(
stub.Get,
request,
Expand Down Expand Up @@ -153,11 +146,7 @@ async def _get_result(self, *, timeout: float = 60) -> ResultTypeT:

async def _cancel(self, *, timeout: float = 60) -> OperationStatus:
request = CancelOperationRequest(operation_id=self.id)
async with self._client.get_service_stub(
OperationServiceStub,
timeout=timeout,
service_name=self._service_name,
) as stub:
async with self._client.get_service_stub(OperationServiceStub, timeout=timeout) as stub:
response = await self._client.call_service(
stub.Cancel,
request,
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ def fixture_async_sdk(
interceptors=interceptors,
auth=auth,
retry_policy=retry_policy,
service_map={ # TMP
service_map={
# NOT SO TMP after all
# to remove this, we need to regenerate all of assistant tests cassetes
# and maybe change etalons in tests, so it needs some effort
'ai-files': 'assistant.api.cloud.yandex.net',
'ai-assistants': 'assistant.api.cloud.yandex.net',
}
Expand Down

0 comments on commit a98bd6b

Please sign in to comment.