Skip to content

Commit

Permalink
Merge 0.1.2 to a master (#21)
Browse files Browse the repository at this point in the history
* Add typing_extensions dep

* Fix for langchain >= 0.3 (#19)

* Support for 3.13 (#20)

* Bump version

* Fix for mypy dataclass inheritance
  • Loading branch information
vhaldemar authored Nov 5, 2024
1 parent 8c4d8af commit 73dd7dc
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12' ]
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12', '3.13' ]
env: [ '', '-extra-deps']
timeout-minutes: 10
steps:
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ For more usage examples look into `examples` folder.

## Langchain integration

To use langchain integration, install `yandex-cloud-ml-sdk` package with a `langchain` extra:

```sh
pip install yandex-cloud-ml-sdk[langchain]
```

Usage example:

```python
from yandex_cloud_ml_sdk import YCloudML
from langchain_core.messages import AIMessage, HumanMessage
Expand Down
30 changes: 30 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

import pathlib
import sys

pytest_plugins = [
'pytest_asyncio',
'pytest_recording',
'yandex_cloud_ml_sdk._testing.plugin',
]

langchain_paths = [
'examples/langchain/',
'tests/langchain_',
'src/yandex_cloud_ml_sdk/_types/langchain.py',
'src/yandex_cloud_ml_sdk/_utils/langchain.py',
'src/yandex_cloud_ml_sdk/_models/completions/langchain.py',
]

def pytest_ignore_collect(collection_path, path, config): # pylint: disable=unused-argument
if sys.version_info > (3, 9):
return None

base_path = pathlib.Path(__file__).parent
for suffix in langchain_paths:
path_to_ignore = base_path / suffix
if str(collection_path).startswith(str(path_to_ignore)):
return True

return None
16 changes: 14 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Operating System :: OS Independent",
"Operating System :: POSIX",
"Operating System :: MacOS",
Expand All @@ -46,7 +47,7 @@ test = [
"tox",
]
langchain = [
"langchain-core>=0.2.29,<=0.2.41"
"langchain-core>=0.3"
]

[project.urls]
Expand Down Expand Up @@ -95,6 +96,9 @@ asyncio_default_fixture_loop_scope = "function"
filterwarnings = [
"ignore::DeprecationWarning:pytest_pylint"
]
addopts = [
"--import-mode=importlib",
]

[tool.pylint]
ignore="CVS"
Expand Down Expand Up @@ -263,7 +267,15 @@ overgeneral-exceptions= [
[[tool.mypy.overrides]]
module = [
"get_annotations",
"scipy.spatial.distance",
"langchain_core",
"langchain_core.callbacks",
"langchain_core.callbacks.manager",
"langchain_core.language_models.base",
"langchain_core.language_models.chat_models",
"langchain_core.messages",
"langchain_core.messages.ai",
"langchain_core.outputs",
"pydantic",
"scipy.spatial.distance",
]
ignore_missing_imports = true
12 changes: 6 additions & 6 deletions src/yandex_cloud_ml_sdk/_models/completions/langchain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import asdict
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, TypeVar
from typing import Any, AsyncIterator, Iterator, TypeVar

from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
Expand All @@ -14,12 +14,9 @@
from yandex_cloud_ml_sdk._utils.sync import run_sync_generator_impl, run_sync_impl

from .message import TextMessageDict
from .model import BaseGPTModel # pylint: disable=cyclic-import
from .result import Alternative, AlternativeStatus, GPTModelResult

if TYPE_CHECKING:
from .model import BaseGPTModel # noqa


GenerationClassT = TypeVar('GenerationClassT', bound=ChatGeneration)


Expand Down Expand Up @@ -53,7 +50,7 @@ def _transform_messages(history: list[BaseMessage]) -> list[TextMessageDict]:
return chat_history


class ChatYandexGPT(BaseYandexLanguageModel['BaseGPTModel'], BaseChatModel):
class ChatYandexGPT(BaseYandexLanguageModel[BaseGPTModel], BaseChatModel):
class Config:
arbitrary_types_allowed = True

Expand Down Expand Up @@ -182,3 +179,6 @@ async def _astream(
text_override=delta,
)
yield generation


ChatYandexGPT.model_rebuild()
23 changes: 13 additions & 10 deletions src/yandex_cloud_ml_sdk/_runs/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import abc
import dataclasses
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from yandex.cloud.ai.assistants.v1.runs.run_pb2 import Run as ProtoRun
from yandex.cloud.ai.assistants.v1.runs.run_service_pb2 import StreamEvent as ProtoStreamEvent
Expand All @@ -17,13 +17,19 @@
if TYPE_CHECKING:
from yandex_cloud_ml_sdk._sdk import BaseSDK

StatusTypeT = TypeVar('StatusTypeT', bound=BaseRunStatus)
MessageTypeT = TypeVar('MessageTypeT', bound=BaseMessage)


@dataclasses.dataclass(frozen=True)
class BaseRunResult(BaseRunStatus, BaseResult[ProtoResultTypeT]):
status: BaseRunStatus
class BaseRunResult(
BaseRunStatus,
BaseResult[ProtoResultTypeT],
Generic[ProtoResultTypeT, StatusTypeT, MessageTypeT]
):
status: StatusTypeT
error: str | None
_message: BaseMessage | None
_message: MessageTypeT | None

@classmethod
@abc.abstractmethod
Expand All @@ -43,7 +49,7 @@ def is_failed(self) -> bool:
return self.status.is_failed

@property
def message(self) -> BaseMessage:
def message(self) -> MessageTypeT:
if self.is_failed:
raise ValueError("run is failed and don't have a message result")
assert self._message
Expand All @@ -59,11 +65,9 @@ def parts(self) -> tuple[Any]:


@dataclasses.dataclass(frozen=True)
class RunResult(BaseRunResult[ProtoRun]):
class RunResult(BaseRunResult[ProtoRun, RunStatus, Message]):
_proto_result_type = ProtoRun

_message: Message | None
status: RunStatus
usage: Usage | None

@classmethod
Expand Down Expand Up @@ -100,9 +104,8 @@ def _from_proto(cls, proto: ProtoRun, sdk: BaseSDK) -> RunResult:


@dataclasses.dataclass(frozen=True)
class RunStreamEvent(BaseRunResult[ProtoStreamEvent]):
class RunStreamEvent(BaseRunResult[ProtoStreamEvent, StreamEvent, BaseMessage]):
_proto_result_type = ProtoStreamEvent
status: StreamEvent

@classmethod
def _from_proto(cls, proto: ProtoStreamEvent, sdk: BaseSDK) -> RunStreamEvent:
Expand Down
19 changes: 12 additions & 7 deletions src/yandex_cloud_ml_sdk/_types/langchain.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
# pylint: disable=abstract-method
# pylint: disable=abstract-method,wrong-import-position
from __future__ import annotations

import sys

if sys.version_info < (3, 9):
raise NotImplementedError("Langchain integration doesn't supported for python<3.9")

from typing import Generic, TypeVar

from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.load import Serializable
from langchain_core.pydantic_v1 import BaseModel as LangchainModel
from pydantic import BaseModel as PydanticModel
from pydantic import ConfigDict

from yandex_cloud_ml_sdk._types.model import BaseModel

ModelTypeT = TypeVar('ModelTypeT', bound=BaseModel)


class BaseYandexModel(LangchainModel, Generic[ModelTypeT]):
class BaseYandexModel(PydanticModel, Generic[ModelTypeT]):
ycmlsdk_model: ModelTypeT
timeout: int = 60

class Config(Serializable.Config):
arbitrary_types_allowed = True

model_config = ConfigDict(
arbitrary_types_allowed=True
)


class BaseYandexLanguageModel(BaseYandexModel[ModelTypeT], BaseLanguageModel):
Expand Down
6 changes: 0 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
)
from yandex_cloud_ml_sdk._types.misc import UNDEFINED

pytest_plugins = [
'pytest_asyncio',
'pytest_recording',
'yandex_cloud_ml_sdk._testing.plugin',
]


@pytest.fixture(name='auth')
def fixture_auth(request):
Expand Down
6 changes: 3 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tox]
envlist = py{3.8,3.9,3.10,3.11,3.12},py{3.8,3.9,3.10,3.11,3.12}-extra-deps
envlist = py{3.8,3.9,3.10,3.11,3.12,3.13},py{3.8,3.9,3.10,3.11,3.12,3.13}-extra-deps

[testenv]
deps =
Expand All @@ -15,11 +15,11 @@ commands =
tests \
{posargs}

[testenv:py{3.8,3.9,3.10,3.11,3.12}-extra-deps]
[testenv:py{3.8,3.9,3.10,3.11,3.12,3.13}-extra-deps]
deps =
-r test_requirements.txt
numpy
langchain-core<=0.2.41
langchain-core>=0.3; python_version >= '3.9'

commands =
pytest \
Expand Down

0 comments on commit 73dd7dc

Please sign in to comment.