diff --git a/install/requirements_py3.11.txt b/install/requirements_py3.11.txt
index c3a3bbbda..e5a241de9 100644
--- a/install/requirements_py3.11.txt
+++ b/install/requirements_py3.11.txt
@@ -22,3 +22,4 @@ environs
 pydantic
Type[VectorDB]:
         if self == DB.AWSOpenSearch:
             from .aws_opensearch.aws_opensearch import AWSOpenSearch
             return AWSOpenSearch
+        if self == DB.Clickhouse:
+            from .clickhouse.clickhouse import Clickhouse
+            return Clickhouse
 
     @property
     def config_cls(self) -> Type[DBConfig]:
@@ -147,6 +151,10 @@ def config_cls(self) -> Type[DBConfig]:
         if self == DB.AWSOpenSearch:
             from .aws_opensearch.config import AWSOpenSearchConfig
             return AWSOpenSearchConfig
+
+        if self == DB.Clickhouse:
+            from .clickhouse.config import ClickhouseConfig
+            return ClickhouseConfig
 
     def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
         if self == DB.Milvus:
diff --git a/vectordb_bench/backend/clients/clickhouse/clickhouse.py b/vectordb_bench/backend/clients/clickhouse/clickhouse.py
new file mode 100644
index 000000000..2d9d0e0b5
--- /dev/null
+++ b/vectordb_bench/backend/clients/clickhouse/clickhouse.py
@@ -0,0 +1,197 @@
+"""Wrapper around the Clickhouse vector database over VectorDB"""
+
+import io
+import logging
+from contextlib import contextmanager
+from typing import Any
+import clickhouse_connect
+import numpy as np
+
+from ..api import VectorDB, DBCaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class Clickhouse(VectorDB):
+    """VectorDB client that stores embeddings in a ClickHouse table.
+
+    All access goes through ``clickhouse_connect``. Rows are ``(id, embedding)``
+    pairs in a MergeTree table and searches rank rows by ``cosineDistance``.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        collection_name: str = "CkVectorCollection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        """Store the settings and, when ``drop_old`` is set, recreate the table.
+
+        Args:
+            dim: Dimension of the vectors.
+            db_config: Connection settings with the keys ``host``, ``port``,
+                ``user``, ``password`` and ``dbname``.
+            db_case_config: Case configuration; stored but not otherwise used here.
+            collection_name: Name of the table holding the embeddings.
+            drop_old: Drop the table if it exists and create it again.
+            **kwargs: Ignored.
+        """
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pqvector_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        # Setup uses a short-lived connection; init() opens the working one.
+        # Close it even when the drop/create fails so the client is not leaked.
+        self.conn = self._connect()
+        try:
+            if drop_old:
+                log.info(f"Clickhouse client drop table : {self.table_name}")
+                self._drop_table()
+                self._create_table(dim)
+        finally:
+            self.conn.close()
+            self.conn = None
+
+    def _connect(self):
+        """Open a new ``clickhouse_connect`` client from ``self.db_config``."""
+        return clickhouse_connect.get_client(
+            host=self.db_config["host"],
+            port=self.db_config["port"],
+            username=self.db_config["user"],
+            password=self.db_config["password"],
+            database=self.db_config["dbname"],
+        )
+
+    def _qualified_table(self) -> str:
+        """Return the table name prefixed with its database."""
+        return f'{self.db_config["dbname"]}.{self.table_name}'
+
+    @contextmanager
+    def init(self) -> None:
+        """Open a connection for the duration of the ``with`` block.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+        self.conn = self._connect()
+        try:
+            yield
+        finally:
+            self.conn.close()
+            self.conn = None
+
+    def _drop_table(self):
+        """Drop the table if it exists."""
+        assert self.conn is not None, "Connection is not initialized"
+
+        self.conn.command(f"DROP TABLE IF EXISTS {self._qualified_table()}")
+
+    def _create_table(self, dim: int):
+        """Create the ``(id, embedding)`` MergeTree table if it does not exist.
+
+        Args:
+            dim: Vector dimension; not used by the table definition.
+
+        Raises:
+            Exception: Whatever the client raised, after logging a warning.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+
+        try:
+            self.conn.command(
+                f"CREATE TABLE IF NOT EXISTS {self._qualified_table()} "
+                "(id Integer, embedding Array(Float32)) "
+                "ENGINE = MergeTree() ORDER BY id;"
+            )
+        except Exception as e:
+            log.warning(
+                f"Failed to create Clickhouse table: {self.table_name} error: {e}"
+            )
+            raise
+
+    def ready_to_load(self):
+        """No preparation is needed before loading."""
+
+    def optimize(self):
+        """No post-load optimization is performed."""
+
+    def ready_to_search(self):
+        """No preparation is needed before searching."""
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> tuple[int, Exception | None]:
+        """Insert one row per ``(metadata[i], embeddings[i])`` pair.
+
+        Args:
+            embeddings: Vectors to insert.
+            metadata: Row ids, aligned with ``embeddings``.
+            **kwargs: Ignored.
+
+        Returns:
+            ``(len(metadata), None)`` on success, ``(0, error)`` on failure.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+
+        try:
+            # strict=True turns mismatched lengths into an error instead of
+            # silently inserting fewer rows than were reported.
+            items = [
+                (row_id, np.array(vector).tolist())
+                for row_id, vector in zip(metadata, embeddings, strict=True)
+            ]
+            self.conn.insert(self.table_name, items, ['id', 'embedding'])
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        """Return the ids of the ``k`` rows closest to ``query``.
+
+        Args:
+            query: Query vector.
+            k: Maximum number of ids to return.
+            filters: Optional filter; when it has an ``id`` entry only rows with
+                a greater id are searched.
+            timeout: Not used.
+
+        Returns:
+            Ids ordered by ascending cosine distance.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+
+        # Plain Python floats so the f-string renders a valid array literal
+        # even when the caller passes numpy scalars.
+        vector = np.asarray(query, dtype=float).tolist()
+
+        # A filter without an "id" bound must not render "WHERE id > None".
+        min_id = filters.get("id") if filters else None
+        where = f"WHERE id > {int(min_id)} " if min_id is not None else ""
+
+        sql = (
+            f"SELECT id,cosineDistance(embedding,{vector}) AS score "
+            f"FROM {self._qualified_table()} "
+            f"{where}ORDER BY score LIMIT {int(k)};"
+        )
+        rows = self.conn.query(sql).result_rows
+        return [int(row[0]) for row in rows]
diff --git a/vectordb_bench/backend/clients/clickhouse/config.py b/vectordb_bench/backend/clients/clickhouse/config.py
new file mode 100644
index 000000000..c132865f8
--- /dev/null
+++ b/vectordb_bench/backend/clients/clickhouse/config.py
@@ -0,0 +1,25 @@
+from typing import TypedDict
+from pydantic import BaseModel, SecretStr
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+
+class ClickhouseConfig(DBConfig):
+    """Connection settings for a ClickHouse server."""
+
+    # Wrap the default explicitly: pydantic does not validate field defaults
+    # unless asked to, so a bare str would break get_secret_value() in to_dict().
+    user_name: SecretStr = SecretStr("default")
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 30193
+    db_name: str = "default"
+
+    def to_dict(self) -> dict:
+        """Return the settings as the plain dict the Clickhouse client reads."""
+        return {
+            "host": self.host,
+            "port": self.port,
+            "dbname": self.db_name,
+            "user": self.user_name.get_secret_value(),
+            "password": self.password.get_secret_value(),
+        }