Skip to content

Commit

Permalink
Fix typing and linter
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar committed Dec 5, 2024
1 parent 251343e commit f35cbe9
Show file tree
Hide file tree
Showing 14 changed files with 108 additions and 81 deletions.
16 changes: 9 additions & 7 deletions examples/async/tuning/attach.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ def local_path(path: str) -> pathlib.Path:
return pathlib.Path(__file__).parent / path


async def main() -> None:
sdk = AsyncYCloudML(
folder_id='b1ghsjum2v37c2un8h64',
)

async def get_datasets(sdk):
async for dataset in sdk.datasets.list(status="READY"):
print(f'using old dataset {dataset=}')
break
Expand All @@ -33,11 +29,17 @@ async def main() -> None:
dataset = await operation
print(f'created new dataset {dataset=}')

return dataset, dataset


async def main() -> None:
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
train_dataset, validation_dataset = await get_datasets(sdk)
base_model = sdk.models.completions('yandexgpt-lite')

tuning_task = await base_model.tune_deferred(
dataset,
validation_datasets=dataset,
train_dataset,
validation_datasets=validation_dataset,
name=str(uuid.uuid4())
)
print(f'new {tuning_task=}')
Expand Down
16 changes: 9 additions & 7 deletions examples/async/tuning/cancel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ def local_path(path: str) -> pathlib.Path:
return pathlib.Path(__file__).parent / path


async def main() -> None:
sdk = AsyncYCloudML(
folder_id='b1ghsjum2v37c2un8h64',
)

async def get_datasets(sdk):
async for dataset in sdk.datasets.list(status="READY"):
print(f'using old dataset {dataset=}')
break
Expand All @@ -33,11 +29,17 @@ async def main() -> None:
dataset = await operation
print(f'created new dataset {dataset=}')

return dataset, dataset


async def main() -> None:
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
train_dataset, validation_dataset = await get_datasets(sdk)
base_model = sdk.models.completions('yandexgpt-lite')

tuning_task = await base_model.tune_deferred(
dataset,
validation_datasets=dataset,
train_dataset,
validation_datasets=validation_dataset,
name=str(uuid.uuid4())
)
print(f'new {tuning_task=}')
Expand Down
17 changes: 10 additions & 7 deletions examples/async/tuning/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ def local_path(path: str) -> pathlib.Path:
return pathlib.Path(__file__).parent / path


async def main() -> None:
sdk = AsyncYCloudML(
folder_id='b1ghsjum2v37c2un8h64',
)

async def get_datasets(sdk):
async for dataset in sdk.datasets.list(status="READY"):
print(f'using old dataset {dataset=}')
break
Expand All @@ -33,11 +29,17 @@ async def main() -> None:
dataset = await operation
print(f'created new dataset {dataset=}')

return dataset, dataset


async def main() -> None:
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
train_dataset, validation_dataset = await get_datasets(sdk)
base_model = sdk.models.completions('yandexgpt-lite')

tuning_task = await base_model.tune_deferred(
dataset,
validation_datasets=dataset,
train_dataset,
validation_datasets=validation_dataset,
name=str(uuid.uuid4())
)
print(f'new {tuning_task=}')
Expand All @@ -48,6 +50,7 @@ async def report_status():
print(f'{await tuning_task.get_status()=}')
print(f'{await tuning_task.get_task_info()=}')
print(f'{await tuning_task.get_metrics_url()=}')
print()
await asyncio.sleep(5)

report_task = asyncio.create_task(report_status())
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ disable = [
'fixme',
'missing-class-docstring', # TMP
'missing-function-docstring', # TMP
'duplicate-code', # TMP
]
load-plugins = [
'pylint_protobuf'
Expand Down Expand Up @@ -243,7 +244,7 @@ valid-classmethod-first-arg="cls"
valid-metaclass-classmethod-first-arg="cls"

[tool.pylint.'DESIGN']
max-args=15
max-args=16
max-attributes=15
max-bool-expr=5
max-branches=12
Expand Down
6 changes: 3 additions & 3 deletions src/yandex_cloud_ml_sdk/_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from get_annotations import get_annotations

from yandex_cloud_ml_sdk._types.domain import BaseDomain
from yandex_cloud_ml_sdk._types.function import BaseFunction
from yandex_cloud_ml_sdk._types.function import BaseModelFunction
from yandex_cloud_ml_sdk._types.model import ModelTuneMixin

from .completions.function import AsyncCompletions, BaseCompletions, Completions
Expand All @@ -23,13 +23,13 @@ class BaseModels(BaseDomain):

def __init__(self, name: str, sdk: BaseSDK):
super().__init__(name=name, sdk=sdk)
self._tuning_map = {}
self._tuning_map: dict[str, type[ModelTuneMixin]] = {}
self._init_functions()

def _init_functions(self) -> None:
members: dict[str, type] = get_annotations(self.__class__, eval_str=True)
for member_name, member_class in members.items():
if not issubclass(member_class, BaseFunction):
if not issubclass(member_class, BaseModelFunction):
continue
function = member_class(name=member_name, sdk=self._sdk, parent_resource=self)
setattr(self, member_name, function)
Expand Down
23 changes: 0 additions & 23 deletions src/yandex_cloud_ml_sdk/_search_indexes/file.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
# pylint: disable=no-name-in-module
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any

from yandex.cloud.ai.assistants.v1.searchindex.search_index_file_pb2 import SearchIndexFile as ProtoSearchIndexFile

from yandex_cloud_ml_sdk._types.resource import BaseResource

from .chunking_strategy import BaseIndexChunkingStrategy

if TYPE_CHECKING:
from yandex_cloud_ml_sdk._sdk import BaseSDK


@dataclass(frozen=True)
class SearchIndexFile(BaseResource):
search_index_id: str
created_by: str
created_at: datetime
chunking_strategy: BaseIndexChunkingStrategy

@classmethod
def _kwargs_from_message(
cls,
proto: ProtoSearchIndexFile, # type: ignore[override]
sdk: BaseSDK
) -> dict[str, Any]:
kwargs = super()._kwargs_from_message(proto, sdk=sdk)
# pylint: disable=protected-access
kwargs['chunking_strategy'] = BaseIndexChunkingStrategy._from_upper_proto(
proto=proto.chunking_strategy, sdk=sdk
)
return kwargs
46 changes: 39 additions & 7 deletions src/yandex_cloud_ml_sdk/_tuning/domain.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# pylint: disable=protected-access,no-name-in-module
from __future__ import annotations

from typing import Generic, TypeVar, cast
from typing import AsyncIterator, Generic, Iterator

from yandex.cloud.ai.tuning.v1.tuning_service_pb2 import (
DescribeTuningRequest, DescribeTuningResponse, GetOptionsRequest, GetOptionsResponse, TuningRequest
)
from yandex.cloud.ai.tuning.v1.tuning_service_pb2 import GetOptionsRequest, GetOptionsResponse, TuningRequest
from yandex.cloud.ai.tuning.v1.tuning_service_pb2_grpc import TuningServiceStub
from yandex.cloud.operation.operation_pb2 import Operation as ProtoOperation

from yandex_cloud_ml_sdk._types.domain import BaseDomain
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, PathLike, UndefinedOr, coerce_path, get_defined_value, is_defined
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
from yandex_cloud_ml_sdk._types.model import ModelTuneMixin
from yandex_cloud_ml_sdk._types.tuning.datasets import TuningDatasetsType, coerce_datasets
from yandex_cloud_ml_sdk._types.tuning.params import BaseTuningParams
Expand All @@ -30,7 +28,8 @@ def _to_weighted_datasets(
if not defined:
return ()

coerced = coerce_datasets(defined)
# mypy breaks here, it thinks `defined` have a type "object"
coerced = coerce_datasets(defined) # type: ignore[arg-type]

return tuple(
TuningRequest.WeightedDataset(
Expand Down Expand Up @@ -135,7 +134,7 @@ async def _list(
page_size: UndefinedOr[int] = UNDEFINED,
timeout: float = 60
) -> AsyncIterator[TuningTaskTypeT]:
pass
yield 1


class AsyncTuning(BaseTuning[AsyncTuningTask]):
Expand All @@ -149,6 +148,39 @@ async def get(
) -> AsyncTuningTask:
return await self._get(task_id=task_id, timeout=timeout)

async def list(
self,
*,
page_size: UndefinedOr[int] = UNDEFINED,
timeout: float = 60
) -> AsyncIterator[AsyncTuningTask]:
async for task in self._list(
page_size=page_size,
timeout=timeout
):
yield task


class Tuning(BaseTuning[TuningTask]):
_tuning_impl = TuningTask
__get = run_sync(BaseTuning._get)
__list = run_sync_generator(BaseTuning._list)

def get(
self,
task_id: str,
*,
timeout: float = 60,
) -> TuningTask:
return self.__get(task_id=task_id, timeout=timeout)

def list(
self,
*,
page_size: UndefinedOr[int] = UNDEFINED,
timeout: float = 60
) -> Iterator[TuningTask]:
yield from self.__list(
page_size=page_size,
timeout=timeout
)
26 changes: 17 additions & 9 deletions src/yandex_cloud_ml_sdk/_tuning/tuning_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, TypeVar, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast

from grpc import StatusCode
from grpc.aio import AioRpcError
Expand Down Expand Up @@ -71,7 +71,7 @@ def _from_tuning_info(cls, info: TuningTaskInfo) -> TuningTaskStatus:
error = None
if info.status == TuningTaskStatusEnum.COMPLETED:
done = True
elif proto.status == TuningTaskStatusEnum.FAILED:
elif info.status == TuningTaskStatusEnum.FAILED:
done = True
error = OperationErrorInfo(
code=-1,
Expand All @@ -82,7 +82,7 @@ def _from_tuning_info(cls, info: TuningTaskInfo) -> TuningTaskStatus:
return cls(
done=done,
error=error,
response=proto,
response=info,
metadata=None,
)

Expand All @@ -102,7 +102,7 @@ def __init__(
self._operation_id = operation_id
self._task_id = task_id

if not operation_id and not task:
if not operation_id and not task_id:
raise TypeError('tuning task must be created with operation_id either with task_id')

@property
Expand All @@ -119,6 +119,9 @@ def _client(self):
async def _get_operation_id(self, *, timeout: float = 60) -> str | None:
if not self._operation_id:
task_info = await self._get_task_info(timeout=timeout)
if not task_info:
return None

self._operation_id = task_info.operation_id

return self._operation_id
Expand Down Expand Up @@ -199,7 +202,7 @@ async def _get_result(self, *, timeout: float = 60) -> TuningResultTypeT_co:
status = await self._get_status(timeout=timeout)
if status.is_succeeded:
info = await self._get_task_info(timeout=timeout)
if not info.target_model_uri:
if not info or not info.target_model_uri:
raise WrongAsyncOperationStatusError(
f"tuning task {self._task_id} have COMPLETED status but empty target_model_uri"
)
Expand Down Expand Up @@ -228,6 +231,11 @@ async def _cancel(self, *, timeout: float = 60) -> None:
# 2) after operation expire

operation_id = await self._get_operation_id(timeout=timeout)
if not operation_id:
raise WrongAsyncOperationStatusError(
f"failed to cancel tuning task {self.id} because "
"it already gone from operations storage (few weeks)"
)

request = CancelOperationRequest(operation_id=operation_id)
async with self._client.get_service_stub(
Expand Down Expand Up @@ -268,7 +276,7 @@ async def _get_metrics_url(self, *, timeout: float = 60) -> str | None:


class AsyncTuningTask(BaseTuningTask[TuningResultTypeT_co]):
async def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo:
async def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo | None:
return await self._get_task_info(timeout=timeout)

async def get_status(self, *, timeout: float = 60) -> TuningTaskStatus:
Expand All @@ -293,7 +301,7 @@ async def wait(
poll_interval=poll_interval,
)

async def get_metrics_url(self, *, timeout: float = 60) -> str:
async def get_metrics_url(self, *, timeout: float = 60) -> str | None:
return await self._get_metrics_url(timeout=timeout)

def __await__(self):
Expand All @@ -308,7 +316,7 @@ class TuningTask(BaseTuningTask[TuningResultTypeT_co]):
__get_metrics_url = run_sync(BaseTuningTask._get_metrics_url)
__get_task_info = run_sync(BaseTuningTask._get_task_info)

def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo:
def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo | None:
return self.__get_task_info(timeout=timeout)

def get_status(self, *, timeout: float = 60) -> TuningTaskStatus:
Expand All @@ -334,7 +342,7 @@ def wait(
)
return cast(TuningResultTypeT_co, result)

def get_metrics_url(self, *, timeout: float = 60) -> str:
def get_metrics_url(self, *, timeout: float = 60) -> str | None:
return self.__get_metrics_url(timeout=timeout)


Expand Down
4 changes: 2 additions & 2 deletions src/yandex_cloud_ml_sdk/_types/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def _tune(
self,
timeout: float = 60,
poll_timeout: int = 72 * 60 * 60,
poll_inteval: float = 60,
poll_interval: float = 60,
**kwargs,
) -> Self:
operation = await self._tune_deferred(
Expand All @@ -125,7 +125,7 @@ async def _tune(
result = await operation._wait(
timeout=timeout,
poll_timeout=poll_timeout,
poll_inteval=poll_inteval,
poll_interval=poll_interval,
)
return result

Expand Down
Loading

0 comments on commit f35cbe9

Please sign in to comment.