-
Notifications
You must be signed in to change notification settings - Fork 167
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Cai Yudong <[email protected]>
- Loading branch information
Showing
7 changed files
with
405 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
159 changes: 159 additions & 0 deletions
159
vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
import logging | ||
from contextlib import contextmanager | ||
import time | ||
from typing import Iterable, Type | ||
from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType | ||
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig | ||
from opensearchpy import OpenSearch | ||
from opensearchpy.helpers import bulk | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class AWSOpenSearch(VectorDB): | ||
def __init__( | ||
self, | ||
dim: int, | ||
db_config: dict, | ||
db_case_config: AWSOpenSearchIndexConfig, | ||
index_name: str = "vdb_bench_index", # must be lowercase | ||
id_col_name: str = "id", | ||
vector_col_name: str = "embedding", | ||
drop_old: bool = False, | ||
**kwargs, | ||
): | ||
self.dim = dim | ||
self.db_config = db_config | ||
self.case_config = db_case_config | ||
self.index_name = index_name | ||
self.id_col_name = id_col_name | ||
self.category_col_names = [ | ||
f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000] | ||
] | ||
self.vector_col_name = vector_col_name | ||
|
||
log.info(f"AWS_OpenSearch client config: {self.db_config}") | ||
client = OpenSearch(**self.db_config) | ||
if drop_old: | ||
log.info(f"AWS_OpenSearch client drop old index: {self.index_name}") | ||
is_existed = client.indices.exists(index=self.index_name) | ||
if is_existed: | ||
client.indices.delete(index=self.index_name) | ||
self._create_index(client) | ||
|
||
@classmethod | ||
def config_cls(cls) -> AWSOpenSearchConfig: | ||
return AWSOpenSearchConfig | ||
|
||
@classmethod | ||
def case_config_cls( | ||
cls, index_type: IndexType | None = None | ||
) -> AWSOpenSearchIndexConfig: | ||
return AWSOpenSearchIndexConfig | ||
|
||
def _create_index(self, client: OpenSearch): | ||
settings = { | ||
"index": { | ||
"knn": True, | ||
# "number_of_shards": 5, | ||
# "refresh_interval": "600s", | ||
} | ||
} | ||
mappings = { | ||
"properties": { | ||
self.id_col_name: {"type": "integer"}, | ||
**{ | ||
categoryCol: {"type": "keyword"} | ||
for categoryCol in self.category_col_names | ||
}, | ||
self.vector_col_name: { | ||
"type": "knn_vector", | ||
"dimension": self.dim, | ||
"method": self.case_config.index_param(), | ||
}, | ||
} | ||
} | ||
try: | ||
client.indices.create( | ||
index=self.index_name, body=dict(settings=settings, mappings=mappings) | ||
) | ||
except Exception as e: | ||
log.warning(f"Failed to create index: {self.index_name} error: {str(e)}") | ||
raise e from None | ||
|
||
@contextmanager | ||
def init(self) -> None: | ||
"""connect to elasticsearch""" | ||
self.client = OpenSearch(**self.db_config) | ||
|
||
yield | ||
# self.client.transport.close() | ||
self.client = None | ||
del self.client | ||
|
||
def insert_embeddings( | ||
self, | ||
embeddings: Iterable[list[float]], | ||
metadata: list[int], | ||
**kwargs, | ||
) -> tuple[int, Exception]: | ||
"""Insert the embeddings to the elasticsearch.""" | ||
assert self.client is not None, "should self.init() first" | ||
|
||
insert_data = [] | ||
for i in range(len(embeddings)): | ||
insert_data.append({"index": {"_index": self.index_name, "_id": metadata[i]}}) | ||
insert_data.append({self.vector_col_name: embeddings[i]}) | ||
try: | ||
resp = self.client.bulk(insert_data) | ||
log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}") | ||
resp = self.client.indices.stats(self.index_name) | ||
log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}") | ||
return (len(embeddings), None) | ||
except Exception as e: | ||
log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}") | ||
time.sleep(10) | ||
return self.insert_embeddings(embeddings, metadata) | ||
|
||
def search_embedding( | ||
self, | ||
query: list[float], | ||
k: int = 100, | ||
filters: dict | None = None, | ||
) -> list[int]: | ||
"""Get k most similar embeddings to query vector. | ||
Args: | ||
query(list[float]): query embedding to look up documents similar to. | ||
k(int): Number of most similar embeddings to return. Defaults to 100. | ||
filters(dict, optional): filtering expression to filter the data while searching. | ||
Returns: | ||
list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding. | ||
""" | ||
assert self.client is not None, "should self.init() first" | ||
|
||
body = { | ||
"size": k, | ||
"query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}}, | ||
} | ||
try: | ||
resp = self.client.search(index=self.index_name, body=body) | ||
log.info(f'Search took: {resp["took"]}') | ||
log.info(f'Search shards: {resp["_shards"]}') | ||
log.info(f'Search hits total: {resp["hits"]["total"]}') | ||
result = [int(d["_id"]) for d in resp["hits"]["hits"]] | ||
# log.info(f'success! length={len(res)}') | ||
|
||
return result | ||
except Exception as e: | ||
log.warning(f"Failed to search: {self.index_name} error: {str(e)}") | ||
raise e from None | ||
|
||
def optimize(self): | ||
"""optimize will be called between insertion and search in performance cases.""" | ||
pass | ||
|
||
def ready_to_load(self): | ||
"""ready_to_load will be called before load in load cases.""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Annotated, TypedDict, Unpack | ||
|
||
import click | ||
from pydantic import SecretStr | ||
|
||
from ....cli.cli import ( | ||
CommonTypedDict, | ||
HNSWFlavor2, | ||
cli, | ||
click_parameter_decorators_from_typed_dict, | ||
run, | ||
) | ||
from .. import DB | ||
|
||
|
||
class AWSOpenSearchTypedDict(TypedDict): | ||
host: Annotated[ | ||
str, click.option("--host", type=str, help="Db host", required=True) | ||
] | ||
port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")] | ||
user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")] | ||
password: Annotated[str, click.option("--password", type=str, help="Db password")] | ||
|
||
|
||
class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): | ||
... | ||
|
||
|
||
@cli.command() | ||
@click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict) | ||
def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]): | ||
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig | ||
run( | ||
db=DB.AWSOpenSearch, | ||
db_config=AWSOpenSearchConfig( | ||
host=parameters["host"], | ||
port=parameters["port"], | ||
user=parameters["user"], | ||
password=SecretStr(parameters["password"]), | ||
), | ||
db_case_config=AWSOpenSearchIndexConfig( | ||
), | ||
**parameters, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from enum import Enum | ||
from pydantic import SecretStr, BaseModel | ||
|
||
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType | ||
|
||
|
||
class AWSOpenSearchConfig(DBConfig, BaseModel): | ||
host: str = ( | ||
"xxxxxx.us-west-2.es.amazonaws.com" | ||
) | ||
port: int = 443 | ||
user: str = "admin" | ||
password: SecretStr = "xxxxxx" | ||
|
||
def to_dict(self) -> dict: | ||
return { | ||
"hosts": [{'host': self.host, 'port': self.port}], | ||
"http_auth": (self.user, self.password.get_secret_value()), | ||
"use_ssl": True, | ||
"http_compress": True, | ||
"verify_certs": True, | ||
"ssl_assert_hostname": False, | ||
"ssl_show_warn": False, | ||
"timeout": 600, | ||
} | ||
|
||
|
||
class AWSOS_Engine(Enum): | ||
nmslib = "nmslib" | ||
faiss = "faiss" | ||
lucene = "Lucene" | ||
|
||
|
||
class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig): | ||
metric_type: MetricType = MetricType.L2 | ||
engine: AWSOS_Engine = AWSOS_Engine.nmslib | ||
efConstruction: int = 360 | ||
M: int = 30 | ||
|
||
def parse_metric(self) -> str: | ||
if self.metric_type == MetricType.IP: | ||
return "innerproduct" # only support faiss / nmslib, not for Lucene. | ||
elif self.metric_type == MetricType.COSINE: | ||
return "cosinesimil" | ||
return "l2" | ||
|
||
def index_param(self) -> dict: | ||
params = { | ||
"name": "hnsw", | ||
"space_type": self.parse_metric(), | ||
"engine": self.engine.value, | ||
"parameters": { | ||
"ef_construction": self.efConstruction, | ||
"m": self.M | ||
} | ||
} | ||
return params | ||
|
||
def search_param(self) -> dict: | ||
return {} |
Oops, something went wrong.