-
Notifications
You must be signed in to change notification settings - Fork 167
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
340eb41
commit afc8fa8
Showing
4 changed files
with
274 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
162 changes: 162 additions & 0 deletions
162
vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import logging | ||
import time | ||
from contextlib import contextmanager | ||
from typing import Iterable | ||
from ..api import VectorDB | ||
from .config import AliyunElasticsearchIndexConfig | ||
from elasticsearch.helpers import bulk | ||
|
||
|
||
# Module-level logger used by everything in this file.
log = logging.getLogger(__name__)

# Raise the threshold of the Elasticsearch client libraries' own loggers to
# WARNING so that only warnings and errors from them are emitted.
for logger in ("elasticsearch", "elastic_transport"):
    logging.getLogger(logger).setLevel(logging.WARNING)
|
||
class AliyunElasticsearch(VectorDB):
    """Benchmark client that stores and searches vectors in an Elasticsearch index.

    The constructor only records configuration (and optionally recreates the
    index); the client used for inserts and searches is created by the
    ``init()`` context manager and lives for the duration of its ``with`` body.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: AliyunElasticsearchIndexConfig,
        indice: str = "vdb_bench_indice",  # must be lowercase
        id_col_name: str = "id",
        vector_col_name: str = "vector",
        drop_old: bool = False,
        **kwargs,
    ):
        """Record the configuration and, if requested, recreate the index.

        Args:
            dim: Dimension of the vectors stored in the index.
            db_config: Keyword arguments forwarded to ``Elasticsearch(...)``.
            db_case_config: Index/search parameters for the benchmark case.
            indice: Name of the index to use (must be lowercase).
            id_col_name: Name of the document field holding the integer id.
            vector_col_name: Name of the document field holding the vector.
            drop_old: When true, delete the index if it exists and create it anew.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        self.dim = dim
        self.db_config = db_config
        self.case_config = db_case_config
        self.indice = indice
        self.id_col_name = id_col_name
        self.vector_col_name = vector_col_name
        # Defined up front so the "should self.init() first" assertions fire
        # with their message instead of an AttributeError.
        self.client = None

        from elasticsearch import Elasticsearch

        # Short-lived client used only for index setup; always closed so its
        # connections are not leaked.
        client = Elasticsearch(**self.db_config)
        try:
            if drop_old:
                log.info(f"Elasticsearch client drop_old indices: {self.indice}")
                is_existed_res = client.indices.exists(index=self.indice)
                if is_existed_res.raw:
                    client.indices.delete(index=self.indice)
                self._create_indice(client)
        finally:
            client.close()

    @contextmanager
    def init(self) -> None:
        """Connect to Elasticsearch for the duration of the ``with`` body.

        Sets ``self.client`` on entry. On exit, including when the body raises,
        the client is closed and ``self.client`` is reset to ``None``.
        """
        from elasticsearch import Elasticsearch

        self.client = Elasticsearch(**self.db_config, request_timeout=180)
        try:
            yield
        finally:
            # Detach first so the instance never keeps a reference to a
            # closed (or failed-to-close) client.
            client, self.client = self.client, None
            client.close()

    def _create_indice(self, client) -> None:
        """Create the index with an id field and a vector field.

        The vector field is excluded from ``_source``; its mapping is ``dims``
        plus whatever ``case_config.index_param()`` returns.

        Raises:
            Exception: Whatever ``client.indices.create`` raised, after logging.
        """
        mappings = {
            "_source": {"excludes": [self.vector_col_name]},
            "properties": {
                self.id_col_name: {"type": "integer", "store": True},
                self.vector_col_name: {
                    "dims": self.dim,
                    **self.case_config.index_param(),
                },
            },
        }

        try:
            client.indices.create(index=self.indice, mappings=mappings)
        except Exception as e:
            log.warning(f"Failed to create indice: {self.indice} error: {str(e)}")
            raise e from None

    def insert_embeddings(
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception | None]:
        """Insert the embeddings to the elasticsearch.

        Args:
            embeddings: Vectors to insert; any iterable is accepted.
            metadata: Integer ids, one per embedding, in the same order.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``(inserted_count, None)`` on success, or ``(0, error)`` when
            building or sending the bulk request failed. A length mismatch
            between ``embeddings`` and ``metadata`` is reported as an error.
        """
        assert self.client is not None, "should self.init() first"

        try:
            # zip() instead of len()/indexing so that true iterables work;
            # strict=True surfaces mismatched lengths instead of truncating.
            insert_data = [
                {
                    "_index": self.indice,
                    "_source": {
                        self.id_col_name: doc_id,
                        self.vector_col_name: vector,
                    },
                }
                for doc_id, vector in zip(metadata, embeddings, strict=True)
            ]
            bulk_insert_res = bulk(self.client, insert_data)
            return (bulk_insert_res[0], None)
        except Exception as e:
            log.warning(f"Failed to insert data: {self.indice} error: {str(e)}")
            return (0, e)

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
    ) -> list[int]:
        """Get the ids of the k most similar embeddings to the query vector.

        Args:
            query(list[float]): query embedding to look up documents similar to.
            k(int): Number of most similar embeddings to return. Defaults to 100.
            filters(dict, optional): when given, only documents whose id is
                greater than ``filters["id"]`` are considered.

        Returns:
            list[int]: ids of the matching documents, in the order returned by
            Elasticsearch; empty when nothing matched.

        Raises:
            Exception: Whatever the search raised, after logging.
        """
        assert self.client is not None, "should self.init() first"

        knn = {
            "field": self.vector_col_name,
            "k": k,
            "num_candidates": self.case_config.num_candidates,
            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}]
            if filters
            else [],
            "query_vector": query,
        }
        try:
            res = self.client.search(
                index=self.indice,
                knn=knn,
                size=k,
                _source=False,
                docvalue_fields=[self.id_col_name],
                stored_fields="_none_",
                filter_path=[f"hits.hits.fields.{self.id_col_name}"],
            )
            # filter_path removes empty objects, so a search without matches
            # comes back without any "hits" key at all.
            hits = res.get("hits", {}).get("hits", [])
            return [h["fields"][self.id_col_name][0] for h in hits]
        except Exception as e:
            log.warning(f"Failed to search: {self.indice} error: {str(e)}")
            raise e from None

    def optimize(self):
        """optimize will be called between insertion and search in performance cases.

        Refreshes the index, starts an asynchronous force-merge down to one
        segment, then polls the task every 30 seconds until it reports
        completion.
        """
        assert self.client is not None, "should self.init() first"
        self.client.indices.refresh(index=self.indice)
        force_merge_task_id = self.client.indices.forcemerge(
            index=self.indice,
            max_num_segments=1,
            wait_for_completion=False,
        )['task']
        log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
        SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
        while True:
            time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
            task_status = self.client.tasks.get(task_id=force_merge_task_id)
            if task_status['completed']:
                return

    def ready_to_load(self):
        """ready_to_load will be called before load in load cases."""
        pass
60 changes: 60 additions & 0 deletions
60
vectordb_bench/backend/clients/aliyun_elasticsearch/config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from enum import Enum | ||
from pydantic import SecretStr, BaseModel | ||
|
||
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType | ||
|
||
|
||
class AliyunElasticsearchConfig(DBConfig, BaseModel):
    """Connection settings (node address and basic-auth credentials)."""

    #: Protocol in use to connect to the node
    scheme: str = "http"
    host: str = ""
    port: int = 9200
    user: str = "elastic"
    password: SecretStr

    def to_dict(self) -> dict:
        """Return the settings as a dict with ``hosts`` and ``basic_auth`` keys.

        ``hosts`` is a one-element list describing the node; ``basic_auth`` is
        a ``(user, password)`` tuple with the secret password unwrapped.
        """
        node = {'scheme': self.scheme, 'host': self.host, 'port': self.port}
        credentials = (self.user, self.password.get_secret_value())
        return {
            "hosts": [node],
            "basic_auth": credentials,
        }
|
||
|
||
class ESElementType(str, Enum):
    """Element types selectable for the vector field; values are plain strings."""

    # 4 bytes per element
    float = "float"
    # 1 byte per element, range -128 to 127
    byte = "byte"
|
||
|
||
class AliyunElasticsearchIndexConfig(BaseModel, DBCaseConfig):
    """Index-build and search parameters for one benchmark case."""

    element_type: ESElementType = ESElementType.float
    # ES only supports 'hnsw'
    index: IndexType = IndexType.ES_HNSW

    metric_type: MetricType | None = None
    efConstruction: int | None = None
    M: int | None = None
    num_candidates: int | None = None

    def parse_metric(self) -> str:
        """Map ``metric_type`` to a similarity name.

        Returns ``"l2_norm"`` for L2, ``"dot_product"`` for IP, and
        ``"cosine"`` for anything else (including an unset metric).
        """
        metric = self.metric_type
        if metric == MetricType.IP:
            return "dot_product"
        if metric == MetricType.L2:
            return "l2_norm"
        return "cosine"

    def index_param(self) -> dict:
        """Build the ``dense_vector`` field definition for this configuration."""
        index_options = {
            "type": self.index.value,
            "m": self.M,
            "ef_construction": self.efConstruction,
        }
        return {
            "type": "dense_vector",
            "index": True,
            "element_type": self.element_type.value,
            "similarity": self.parse_metric(),
            "index_options": index_options,
        }

    def search_param(self) -> dict:
        """Return the search-time parameters (currently only ``num_candidates``)."""
        return {"num_candidates": self.num_candidates}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters