Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hybrid search index #46

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions examples/async/assistants/hybrid_search_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

from __future__ import annotations

import asyncio
import pathlib
import pprint

from yandex_cloud_ml_sdk import AsyncYCloudML
from yandex_cloud_ml_sdk.search_indexes import (
HybridSearchIndexType, ReciprocalRankFusionIndexCombinationStrategy, StaticIndexChunkingStrategy,
TextSearchIndexType, VectorSearchIndexType
)


def local_path(path: str) -> pathlib.Path:
return pathlib.Path(__file__).parent / path


async def main() -> None:
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')

file_coros = (
sdk.files.upload(
local_path(path),
ttl_days=5,
expiration_policy="static",
)
for path in ['turkey_example.txt', 'maldives_example.txt']
)
files = await asyncio.gather(*file_coros)

# How to create search index with all default settings:
operation = await sdk.search_indexes.create_deferred(
files,
index_type=HybridSearchIndexType()
)
default_search_index = await operation.wait()
print("new hybrid search index with default settings:")
pprint.pprint(default_search_index)

# But you could override any default:
operation = await sdk.search_indexes.create_deferred(
files,
index_type=HybridSearchIndexType(
chunking_strategy=StaticIndexChunkingStrategy(
max_chunk_size_tokens=700,
chunk_overlap_tokens=300
),
# you could also override some text/vector indexes settings
text_search_index=TextSearchIndexType(),
vector_search_index=VectorSearchIndexType(),
normalization_strategy='L2',
combination_strategy=ReciprocalRankFusionIndexCombinationStrategy(
k=10
vhaldemar marked this conversation as resolved.
Show resolved Hide resolved
)
)
)
search_index = await operation.wait()
print("new hybrid search index with overridden settings:")
pprint.pprint(search_index)

# And how to use your index you could learn in example file "assistant_with_search_index.py".
# Working with hybrid index does not differ from working with any other index besides creation.

# Created resources cleanup:
for file in files:
await file.delete()

for search_index in [default_search_index, search_index]:
print(f"delete {search_index.id=}")
await search_index.delete()


if __name__ == '__main__':
asyncio.run(main())
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ exclude-protected= [
"_to_proto",
"_result_type",
"_proto_result_type",
"_proto_field_name",
"_coerce",
]
valid-classmethod-first-arg="cls"
valid-metaclass-classmethod-first-arg="cls"
Expand Down
109 changes: 109 additions & 0 deletions src/yandex_cloud_ml_sdk/_search_indexes/combination_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# pylint: disable=no-name-in-module,protected-access
from __future__ import annotations

import abc
import enum
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Collection

from google.protobuf.wrappers_pb2 import Int64Value
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import CombinationStrategy as ProtoCombinationStrategy
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import MeanCombinationStrategy as ProtoMeanCombinationStrategy
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import (
ReciprocalRankFusionCombinationStrategy as ProtoReciprocalRankFusionCombinationStrategy
)

if TYPE_CHECKING:
from yandex_cloud_ml_sdk._sdk import BaseSDK


class BaseIndexCombinationStrategy(abc.ABC):
@classmethod
@abc.abstractmethod
def _from_proto(cls, proto: Any, sdk: BaseSDK) -> BaseIndexCombinationStrategy:
pass

@abc.abstractmethod
def _to_proto(self) -> ProtoCombinationStrategy:
pass

@classmethod
def _from_upper_proto(cls, proto: ProtoCombinationStrategy, sdk: BaseSDK) -> BaseIndexCombinationStrategy:
if proto.HasField('mean_combination'):
return MeanIndexCombinationStrategy._from_proto(
proto=proto.mean_combination,
sdk=sdk
)
if proto.HasField('rrf_combination'):
return ReciprocalRankFusionIndexCombinationStrategy._from_proto(
proto=proto.rrf_combination,
sdk=sdk
)
raise NotImplementedError(
'combination strategies other then Mean and RRF are not supported in this SDK version'
)


_orig = ProtoMeanCombinationStrategy.MeanEvaluationTechnique

class MeanIndexEvaluationTechnique(enum.IntEnum):
MEAN_EVALUATION_TECHNIQUE_UNSPECIFIED = _orig.MEAN_EVALUATION_TECHNIQUE_UNSPECIFIED
ARITHMETIC = _orig.ARITHMETIC
GEOMETRIC = _orig.GEOMETRIC
HARMONIC = _orig.HARMONIC

@classmethod
def _coerce(cls, technique: str | int ) -> MeanIndexEvaluationTechnique:
if isinstance(technique, str):
technique = _orig.Value(technique.upper())
return cls(technique)


@dataclass(frozen=True)
class MeanIndexCombinationStrategy(BaseIndexCombinationStrategy):
mean_evaluation_technique: MeanIndexEvaluationTechnique | None
weights: Collection[float] | None

@classmethod
# pylint: disable=unused-argument
def _from_proto(cls, proto: ProtoMeanCombinationStrategy, sdk: BaseSDK) -> MeanIndexCombinationStrategy:
return cls(
mean_evaluation_technique=MeanIndexEvaluationTechnique._coerce(proto.mean_evaluation_technique),
weights=tuple(proto.weights)
)

def _to_proto(self) -> ProtoCombinationStrategy:
kwargs: dict[str, Any] = {}
if self.mean_evaluation_technique:
kwargs['mean_evaluation_technique'] = int(self.mean_evaluation_technique)
if self.weights is not None:
kwargs['weghts'] = tuple(self.weights)

return ProtoCombinationStrategy(
mean_combination=ProtoMeanCombinationStrategy(**kwargs)
)


@dataclass(frozen=True)
class ReciprocalRankFusionIndexCombinationStrategy(BaseIndexCombinationStrategy):
k: int | None = None

@classmethod
# pylint: disable=unused-argument
def _from_proto(
cls, proto: ProtoReciprocalRankFusionCombinationStrategy, sdk: BaseSDK
) -> ReciprocalRankFusionIndexCombinationStrategy:
kwargs = {}
if proto.HasField('k'):
kwargs['k'] = proto.k.value
return ReciprocalRankFusionIndexCombinationStrategy(
**kwargs
)

def _to_proto(self) -> ProtoCombinationStrategy:
kwargs = {}
if self.k is not None:
kwargs['k'] = Int64Value(value=self.k)
return ProtoCombinationStrategy(
rrf_combination=ProtoReciprocalRankFusionCombinationStrategy(**kwargs)
)
19 changes: 7 additions & 12 deletions src/yandex_cloud_ml_sdk/_search_indexes/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import AsyncIterator, Generic, Iterator

from yandex.cloud.ai.assistants.v1.searchindex.search_index_pb2 import SearchIndex as ProtoSearchIndex
from yandex.cloud.ai.assistants.v1.searchindex.search_index_pb2 import TextSearchIndex, VectorSearchIndex
from yandex.cloud.ai.assistants.v1.searchindex.search_index_service_pb2 import (
CreateSearchIndexRequest, GetSearchIndexRequest, ListSearchIndicesRequest, ListSearchIndicesResponse
)
Expand All @@ -19,7 +18,7 @@
from yandex_cloud_ml_sdk._utils.coerce import ResourceType, coerce_resource_ids
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator

from .index_type import BaseSearchIndexType, TextSearchIndexType, VectorSearchIndexType
from .index_type import BaseSearchIndexType
from .search_index import AsyncSearchIndex, SearchIndex, SearchIndexTypeT


Expand Down Expand Up @@ -47,14 +46,11 @@ async def _create_deferred(

expiration_config = ExpirationConfig.coerce(ttl_days=ttl_days, expiration_policy=expiration_policy)

vector_search_index: VectorSearchIndex | None = None
text_search_index: TextSearchIndex | None = None
if isinstance(index_type, VectorSearchIndexType):
vector_search_index = index_type._to_proto()
elif isinstance(index_type, TextSearchIndexType):
text_search_index = index_type._to_proto()
elif is_defined(index_type):
raise TypeError('index type must be instance of SearchIndexType')
kwargs = {}
if is_defined(index_type):
if not isinstance(index_type, BaseSearchIndexType):
raise TypeError('index type must be instance of BaseSearchIndexType')
kwargs[index_type._proto_field_name] = index_type._to_proto()

request = CreateSearchIndexRequest(
folder_id=self._folder_id,
Expand All @@ -63,8 +59,7 @@ async def _create_deferred(
description=get_defined_value(description, ''),
labels=get_defined_value(labels, {}),
expiration_config=expiration_config.to_proto(),
vector_search_index=vector_search_index,
text_search_index=text_search_index,
**kwargs, # type: ignore[arg-type]
)

async with self._client.get_service_stub(SearchIndexServiceStub, timeout=timeout) as stub:
Expand Down
Loading
Loading