From 73dd7dcd6b02f1e1834658456a68774e942fa5d9 Mon Sep 17 00:00:00 2001 From: Vladimir Lipkin Date: Tue, 5 Nov 2024 19:26:55 +0100 Subject: [PATCH] Merge 0.1.2 to a master (#21) * Add typing_extensions dep * Fix for langchain >= 0.3 (#19) * Support for 3.13 (#20) * Bump version * Fix for mypy dataclass inheritance --- .github/workflows/tests.yaml | 2 +- README.md | 8 +++++ conftest.py | 30 +++++++++++++++++++ pyproject.toml | 16 ++++++++-- .../_models/completions/langchain.py | 12 ++++---- src/yandex_cloud_ml_sdk/_runs/result.py | 23 +++++++------- src/yandex_cloud_ml_sdk/_types/langchain.py | 19 +++++++----- tests/conftest.py | 6 ---- tox.ini | 6 ++-- 9 files changed, 87 insertions(+), 35 deletions(-) create mode 100644 conftest.py diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index ff8bb8b..e2f2b00 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12' ] + python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12', '3.13' ] env: [ '', '-extra-deps'] timeout-minutes: 10 steps: diff --git a/README.md b/README.md index d461da1..5664e84 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,14 @@ For more usage examples look into `examples` folder. ## Langchain integration +To use langchain integration, install `yandex-cloud-ml-sdk` package with a `langchain` extra: + +```sh +pip install yandex-cloud-ml-sdk[langchain] +``` + +Usage example: + ```python from yandex_cloud_ml_sdk import YCloudML from langchain_core.messages import AIMessage, HumanMessage diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..b2d10d9 --- /dev/null +++ b/conftest.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import pathlib +import sys + +pytest_plugins = [ + 'pytest_asyncio', + 'pytest_recording', + 'yandex_cloud_ml_sdk._testing.plugin', +] + +langchain_paths = [ + 'examples/langchain/', + 'tests/langchain_', + 'src/yandex_cloud_ml_sdk/_types/langchain.py', + 'src/yandex_cloud_ml_sdk/_utils/langchain.py', + 'src/yandex_cloud_ml_sdk/_models/completions/langchain.py', +] + +def pytest_ignore_collect(collection_path, path, config): # pylint: disable=unused-argument + if sys.version_info > (3, 9): + return None + + base_path = pathlib.Path(__file__).parent + for suffix in langchain_paths: + path_to_ignore = base_path / suffix + if str(collection_path).startswith(str(path_to_ignore)): + return True + + return None diff --git a/pyproject.toml b/pyproject.toml index 063e663..7a98d1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Operating System :: OS Independent", "Operating System :: POSIX", "Operating System :: MacOS", @@ -46,7 +47,7 @@ test = [ "tox", ] langchain = [ - "langchain-core>=0.2.29,<=0.2.41" + "langchain-core>=0.3" ] [project.urls] @@ -95,6 +96,9 @@ asyncio_default_fixture_loop_scope = "function" filterwarnings = [ "ignore::DeprecationWarning:pytest_pylint" ] +addopts = [ + "--import-mode=importlib", +] [tool.pylint] ignore="CVS" @@ -263,7 +267,15 @@ overgeneral-exceptions= [ [[tool.mypy.overrides]] module = [ "get_annotations", - "scipy.spatial.distance", + "langchain_core", + "langchain_core.callbacks", + "langchain_core.callbacks.manager", + "langchain_core.language_models.base", + "langchain_core.language_models.chat_models", "langchain_core.messages", + "langchain_core.messages.ai", + "langchain_core.outputs", + "pydantic", + "scipy.spatial.distance", ] ignore_missing_imports = true diff --git a/src/yandex_cloud_ml_sdk/_models/completions/langchain.py b/src/yandex_cloud_ml_sdk/_models/completions/langchain.py index 0132b9d..b3e17e7 100644 --- a/src/yandex_cloud_ml_sdk/_models/completions/langchain.py +++ b/src/yandex_cloud_ml_sdk/_models/completions/langchain.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import asdict -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, TypeVar +from typing import Any, AsyncIterator, Iterator, TypeVar from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun from langchain_core.language_models.chat_models import BaseChatModel @@ -14,12 +14,9 @@ from yandex_cloud_ml_sdk._utils.sync import run_sync_generator_impl, run_sync_impl from .message import TextMessageDict +from .model import BaseGPTModel # pylint: disable=cyclic-import from .result import Alternative, AlternativeStatus, GPTModelResult -if TYPE_CHECKING: - from .model import BaseGPTModel # noqa - - GenerationClassT = TypeVar('GenerationClassT', bound=ChatGeneration) @@ -53,7 +50,7 @@ def _transform_messages(history: list[BaseMessage]) -> list[TextMessageDict]: return chat_history -class ChatYandexGPT(BaseYandexLanguageModel['BaseGPTModel'], BaseChatModel): +class ChatYandexGPT(BaseYandexLanguageModel[BaseGPTModel], BaseChatModel): class Config: arbitrary_types_allowed = True @@ -182,3 +179,6 @@ async def _astream( text_override=delta, ) yield generation + + +ChatYandexGPT.model_rebuild() diff --git a/src/yandex_cloud_ml_sdk/_runs/result.py b/src/yandex_cloud_ml_sdk/_runs/result.py index b65f512..41b8a88 100644 --- a/src/yandex_cloud_ml_sdk/_runs/result.py +++ b/src/yandex_cloud_ml_sdk/_runs/result.py @@ -3,7 +3,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic, TypeVar from yandex.cloud.ai.assistants.v1.runs.run_pb2 import Run as ProtoRun from yandex.cloud.ai.assistants.v1.runs.run_service_pb2 import StreamEvent as ProtoStreamEvent @@ -17,13 +17,19 @@ if TYPE_CHECKING: from yandex_cloud_ml_sdk._sdk import BaseSDK +StatusTypeT = TypeVar('StatusTypeT', bound=BaseRunStatus) +MessageTypeT = TypeVar('MessageTypeT', bound=BaseMessage) @dataclasses.dataclass(frozen=True) -class BaseRunResult(BaseRunStatus, BaseResult[ProtoResultTypeT]): - status: BaseRunStatus +class BaseRunResult( + BaseRunStatus, + BaseResult[ProtoResultTypeT], + Generic[ProtoResultTypeT, StatusTypeT, MessageTypeT] +): + status: StatusTypeT error: str | None - _message: BaseMessage | None + _message: MessageTypeT | None @classmethod @abc.abstractmethod @@ -43,7 +49,7 @@ def is_failed(self) -> bool: return self.status.is_failed @property - def message(self) -> BaseMessage: + def message(self) -> MessageTypeT: if self.is_failed: raise ValueError("run is failed and don't have a message result") assert self._message @@ -59,11 +65,9 @@ def parts(self) -> tuple[Any]: @dataclasses.dataclass(frozen=True) -class RunResult(BaseRunResult[ProtoRun]): +class RunResult(BaseRunResult[ProtoRun, RunStatus, Message]): _proto_result_type = ProtoRun - _message: Message | None - status: RunStatus usage: Usage | None @classmethod @@ -100,9 +104,8 @@ def _from_proto(cls, proto: ProtoRun, sdk: BaseSDK) -> RunResult: @dataclasses.dataclass(frozen=True) -class RunStreamEvent(BaseRunResult[ProtoStreamEvent]): +class RunStreamEvent(BaseRunResult[ProtoStreamEvent, StreamEvent, BaseMessage]): _proto_result_type = ProtoStreamEvent - status: StreamEvent @classmethod def _from_proto(cls, proto: ProtoStreamEvent, sdk: BaseSDK) -> RunStreamEvent: diff --git a/src/yandex_cloud_ml_sdk/_types/langchain.py b/src/yandex_cloud_ml_sdk/_types/langchain.py index 3faf2b2..204b6ca 100644 --- a/src/yandex_cloud_ml_sdk/_types/langchain.py +++ b/src/yandex_cloud_ml_sdk/_types/langchain.py @@ -1,24 +1,29 @@ -# pylint: disable=abstract-method +# pylint: disable=abstract-method,wrong-import-position from __future__ import annotations +import sys + +if sys.version_info < (3, 9): + raise NotImplementedError("Langchain integration doesn't supported for python<3.9") + from typing import Generic, TypeVar from langchain_core.language_models.base import BaseLanguageModel -from langchain_core.load import Serializable -from langchain_core.pydantic_v1 import BaseModel as LangchainModel +from pydantic import BaseModel as PydanticModel +from pydantic import ConfigDict from yandex_cloud_ml_sdk._types.model import BaseModel ModelTypeT = TypeVar('ModelTypeT', bound=BaseModel) -class BaseYandexModel(LangchainModel, Generic[ModelTypeT]): +class BaseYandexModel(PydanticModel, Generic[ModelTypeT]): ycmlsdk_model: ModelTypeT timeout: int = 60 - class Config(Serializable.Config): - arbitrary_types_allowed = True - + model_config = ConfigDict( + arbitrary_types_allowed=True + ) class BaseYandexLanguageModel(BaseYandexModel[ModelTypeT], BaseLanguageModel): diff --git a/tests/conftest.py b/tests/conftest.py index 4aa7f28..efe8caa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,12 +18,6 @@ ) from yandex_cloud_ml_sdk._types.misc import UNDEFINED -pytest_plugins = [ - 'pytest_asyncio', - 'pytest_recording', - 'yandex_cloud_ml_sdk._testing.plugin', -] - @pytest.fixture(name='auth') def fixture_auth(request): diff --git a/tox.ini b/tox.ini index d59cde8..2b0df8f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py{3.8,3.9,3.10,3.11,3.12},py{3.8,3.9,3.10,3.11,3.12}-extra-deps +envlist = py{3.8,3.9,3.10,3.11,3.12,3.13},py{3.8,3.9,3.10,3.11,3.12,3.13}-extra-deps [testenv] deps = @@ -15,11 +15,11 @@ commands = tests \ {posargs} -[testenv:py{3.8,3.9,3.10,3.11,3.12}-extra-deps] +[testenv:py{3.8,3.9,3.10,3.11,3.12,3.13}-extra-deps] deps = -r test_requirements.txt numpy - langchain-core<=0.2.41 + langchain-core>=0.3; python_version >= '3.9' commands = pytest \