diff --git a/pyproject.toml b/pyproject.toml index 4da9ce2cc..3f517e9e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ all = [ "psycopg-binary", "opensearch-dsl==2.1.0", "opensearch-py==2.6.0", + "mariadb", ] qdrant = [ "qdrant-client" ] @@ -78,6 +79,7 @@ memorydb = [ "memorydb" ] chromadb = [ "chromadb" ] awsopensearch = [ "awsopensearch" ] zilliz_cloud = [] +mariadb = [ "mariadb" ] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 3e87e1fbe..c0405d94a 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -35,6 +35,7 @@ class DB(Enum): MemoryDB = "MemoryDB" Chroma = "Chroma" AWSOpenSearch = "OpenSearch" + MariaDB = "MariaDB" Test = "test" @@ -93,6 +94,10 @@ def init_cls(self) -> Type[VectorDB]: from .aws_opensearch.aws_opensearch import AWSOpenSearch return AWSOpenSearch + if self == DB.MariaDB: + from .mariadb.mariadb import MariaDB + return MariaDB + @property def config_cls(self) -> Type[DBConfig]: """Import while in use""" @@ -148,6 +153,10 @@ def config_cls(self) -> Type[DBConfig]: from .aws_opensearch.config import AWSOpenSearchConfig return AWSOpenSearchConfig + if self == DB.MariaDB: + from .mariadb.config import MariaDBConfig + return MariaDBConfig + def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]: if self == DB.Milvus: from .milvus.config import _milvus_case_config @@ -185,6 +194,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon from .pgvectorscale.config import _pgvectorscale_case_config return _pgvectorscale_case_config.get(index_type) + if self == DB.MariaDB: + from .mariadb.config import _mariadb_case_config + return _mariadb_case_config.get(index_type) + # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/mariadb/cli.py b/vectordb_bench/backend/clients/mariadb/cli.py new file mode 100644 index 000000000..3c333b724 --- /dev/null +++ b/vectordb_bench/backend/clients/mariadb/cli.py @@ -0,0 +1,107 @@ +from typing import Annotated, Optional, Unpack + +import click +import os +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + HNSWFlavor1, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from vectordb_bench.backend.clients import DB + + +class MariaDBTypedDict(CommonTypedDict): + user_name: Annotated[ + str, click.option("--username", + type=str, + help="Username", + required=True, + ), + ] + password: Annotated[ + str, click.option("--password", + type=str, + help="Password", + required=True, + ), + ] + + host: Annotated[ + str, click.option("--host", + type=str, + help="Db host", + default="127.0.0.1", + ), + ] + + port: Annotated[ + int, click.option("--port", + type=int, + default=3306, + help="Db Port", + ), + ] + + storage_engine: Annotated[ + int, click.option("--storage-engine", + type=click.Choice(["InnoDB", "MyISAM"]), + help="DB storage engine", + required=True, + ), + ] + +class MariaDBHNSWTypedDict(MariaDBTypedDict): + ... + m: Annotated[ + Optional[int], click.option("--m", + type=int, + help="MariaDB system variable mhnsw_max_edges_per_node", + required=False, + ), + ] + + ef_search: Annotated[ + Optional[int], click.option("--ef-search", + type=int, + help="MariaDB system variable mhnsw_min_limit", + required=False, + ), + ] + + cache_size: Annotated[ + Optional[int], click.option("--cache-size", + type=int, + help="MariaDB system variable mhnsw_cache_size", + required=False, + ), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict) +def MariaDBHNSW( + **parameters: Unpack[MariaDBHNSWTypedDict], +): + from .config import MariaDBConfig, MariaDBHNSWConfig + + run( + db=DB.MariaDB, + db_config=MariaDBConfig( + db_label=parameters["db_label"], + user_name=parameters["username"], + password=SecretStr(parameters["password"]), + host=parameters["host"], + port=parameters["port"], + ), + db_case_config=MariaDBHNSWConfig( + M=parameters["m"], + ef_search=parameters["ef_search"], + storage_engine=parameters["storage_engine"], + cache_size=parameters["cache_size"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/mariadb/config.py b/vectordb_bench/backend/clients/mariadb/config.py new file mode 100644 index 000000000..50b665a96 --- /dev/null +++ b/vectordb_bench/backend/clients/mariadb/config.py @@ -0,0 +1,71 @@ +from pydantic import SecretStr, BaseModel +from typing import TypedDict +from ..api import DBConfig, DBCaseConfig, MetricType, IndexType + +class MariaDBConfigDict(TypedDict): + """These keys will be directly used as kwargs in mariadb connection string, + so the names must match exactly mariadb API""" + + user: str + password: str + host: str + port: int + + +class MariaDBConfig(DBConfig): + user_name: str = "root" + password: SecretStr + host: str = "127.0.0.1" + port: int = 3306 + + def to_dict(self) -> MariaDBConfigDict: + pwd_str = self.password.get_secret_value() + return { + "host": self.host, + "port": self.port, + "user": self.user_name, + "password": pwd_str, + } + + +class MariaDBIndexConfig(BaseModel): + """Base config for milvus""" + + metric_type: MetricType | None = None + + def parse_metric(self) -> str: + if self.metric_type == MetricType.L2: + return "euclidean" + elif self.metric_type == MetricType.COSINE: + return "cosine" + else: + raise ValueError(f"Metric type {self.metric_type} is not supported!") + +class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig): + M: int | None + ef_search: int | None + index: IndexType = IndexType.HNSW + storage_engine: str = "InnoDB" + cache_size: int | None + + def index_param(self) -> dict: + return { + "storage_engine": self.storage_engine, + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "M": self.M, + "cache_size": self.cache_size, + } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "ef_search": self.ef_search, + } + + +_mariadb_case_config = { + IndexType.HNSW: MariaDBHNSWConfig, +} + + diff --git a/vectordb_bench/backend/clients/mariadb/mariadb.py b/vectordb_bench/backend/clients/mariadb/mariadb.py new file mode 100644 index 000000000..e70ced1f6 --- /dev/null +++ b/vectordb_bench/backend/clients/mariadb/mariadb.py @@ -0,0 +1,194 @@ +from ..api import VectorDB + +import logging +from contextlib import contextmanager +from typing import Any, Optional, Tuple +from ..api import VectorDB +from .config import MariaDBConfigDict, MariaDBIndexConfig +import numpy as np + +import mariadb + +log = logging.getLogger(__name__) + +class MariaDB(VectorDB): + def __init__( + self, + dim: int, + db_config: MariaDBConfigDict, + db_case_config: MariaDBIndexConfig, + collection_name: str = "vec_collection", + drop_old: bool = False, + **kwargs, + ): + + self.name = "MariaDB" + self.db_config = db_config + self.case_config = db_case_config + self.db_name = "vectordbbench" + self.table_name = collection_name + self.dim = dim + + # construct basic units + self.conn, self.cursor = self._create_connection(**self.db_config) + + if drop_old: + self._drop_db() + self._create_db_table(dim) + + self.cursor.close() + self.conn.close() + self.cursor = None + self.conn = None + + + @staticmethod + def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]: + conn = mariadb.connect(**kwargs) + cursor = conn.cursor() + + assert conn is not None, "Connection is not initialized" + assert cursor is not None, "Cursor is not initialized" + + return conn, cursor + + + def _drop_db(self): + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client drop db : {self.db_name}") + + # flush tables before dropping database to avoid some locking issue + self.cursor.execute("FLUSH TABLES") + self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}") + self.cursor.execute("COMMIT") + self.cursor.execute("FLUSH TABLES") + + def _create_db_table(self, dim: int): + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + index_param = self.case_config.index_param() + + try: + log.info(f"{self.name} client create database : {self.db_name}") + self.cursor.execute(f"CREATE DATABASE {self.db_name}") + + log.info(f"{self.name} client create table : {self.table_name}") + self.cursor.execute(f"USE {self.db_name}") + self.cursor.execute(f""" + CREATE TABLE {self.table_name} ( + id INT PRIMARY KEY, + v BLOB NOT NULL, + VECTOR INDEX (v) + distance_function={index_param["metric_type"]} + ) ENGINE={index_param["storage_engine"]} + """) + self.cursor.execute("COMMIT") + + except Exception as e: + log.warning( + f"Failed to create table: {self.table_name} error: {e}" + ) + raise e from None + + + @contextmanager + def init(self) -> None: + """ create and destory connections to database. + + Examples: + >>> with self.init(): + >>> self.insert_embeddings() + """ + self.conn, self.cursor = self._create_connection(**self.db_config) + + index_param = self.case_config.index_param() + search_param = self.case_config.search_param() + + # maximize allowed package size + self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824") + + if index_param["index_type"] == "HNSW": + if index_param["M"] != None: + self.cursor.execute(f"SET mhnsw_max_edges_per_node = {index_param["M"]}") + self.cursor.execute(f"SET GLOBAL mhnsw_max_edges_per_node = {index_param["M"]}") + if index_param["cache_size"] != None: + self.cursor.execute(f"SET GLOBAL mhnsw_cache_size = {index_param["cache_size"]}") + if search_param["ef_search"] != None: + self.cursor.execute(f"SET mhnsw_min_limit = {search_param["ef_search"]}") + self.cursor.execute(f"SET GLOBAL mhnsw_min_limit = {search_param["ef_search"]}") + self.cursor.execute("COMMIT") + + self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" + self.select_sql = f"SELECT id FROM {self.db_name}.{self.table_name} ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d" + + try: + yield + finally: + self.cursor.close() + self.conn.close() + self.cursor = None + self.conn = None + + + def ready_to_load(self) -> bool: + pass + + def optimize(self) -> None: + # TODO: Create index after loading data once MariaDB supports alter table to add Vector index + pass + + @staticmethod + def vector_to_hex(v): + return np.array(v, 'float32').tobytes() + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + """Insert embeddings into the database. + Should call self.init() first. + """ + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + try: + metadata_arr = np.array(metadata) + embeddings_arr = np.array(embeddings) + + batch_data = [] + for i, row in enumerate(metadata_arr): + batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i]))); + + self.cursor.executemany(self.insert_sql, batch_data) + self.cursor.execute("COMMIT") + self.cursor.execute("FLUSH TABLES") + + return len(metadata), None + except Exception as e: + log.warning( + f"Failed to insert data into Vector table ({self.table_name}), error: {e}" + ) + return 0, e + + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> (list[int]): + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + search_param = self.case_config.search_param() + + self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k)) + + return [id for id, in self.cursor.fetchall()] + diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index e62c25a3d..a763bb399 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -8,6 +8,7 @@ from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex from ..backend.clients.milvus.cli import MilvusAutoIndex from ..backend.clients.aws_opensearch.cli import AWSOpenSearch +from ..backend.clients.mariadb.cli import MariaDBHNSW from .cli import cli @@ -22,6 +23,7 @@ cli.add_command(MilvusAutoIndex) cli.add_command(AWSOpenSearch) cli.add_command(PgVectorScaleDiskAnn) +cli.add_command(MariaDBHNSW) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 78d1936d7..54cf6b489 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -788,6 +788,65 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_IndexType_MariaDB = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputHelp="Select Index Type", + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.HNSW.value, + ], + }, +) + +CaseConfigParamInput_StorageEngine_MariaDB = CaseConfigInput( + label=CaseConfigParamType.storage_engine, + inputHelp="Select Storage Engine", + inputType=InputType.Option, + inputConfig={ + "options": ["InnoDB", "MyISAM"], + }, +) + +CaseConfigParamInput_M_MariaDB = CaseConfigInput( + label=CaseConfigParamType.M, + inputHelp="mhnsw_max_edges_per_node", + inputType=InputType.Number, + inputConfig={ + "min": 2, + "max": 200, + "value": 6, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + +CaseConfigParamInput_EFSearch_MariaDB = CaseConfigInput( + label=CaseConfigParamType.ef_search, + inputHelp="mhnsw_min_limit", + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 100, + "value": 20, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + +CaseConfigParamInput_CacheSize_MariaDB = CaseConfigInput( + label=CaseConfigParamType.cache_size, + inputHelp="mhnsw_cache_size", + inputType=InputType.Number, + inputConfig={ + "min": 1048576, + "max": (1 << 53) - 1, + "value": 16 * 1024 ** 3, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + MilvusLoadConfig = [ CaseConfigParamInput_IndexType, CaseConfigParamInput_M, @@ -904,6 +963,21 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_query_search_list_size, ] +MariaDBLoadingConfig = [ + CaseConfigParamInput_IndexType_MariaDB, + CaseConfigParamInput_StorageEngine_MariaDB, + CaseConfigParamInput_M_MariaDB, + CaseConfigParamInput_CacheSize_MariaDB, +] + +MariaDBPerformanceConfig = [ + CaseConfigParamInput_IndexType_MariaDB, + CaseConfigParamInput_StorageEngine_MariaDB, + CaseConfigParamInput_M_MariaDB, + CaseConfigParamInput_CacheSize_MariaDB, + CaseConfigParamInput_EFSearch_MariaDB, +] + CASE_CONFIG_MAP = { DB.Milvus: { CaseLabel.Load: MilvusLoadConfig, @@ -932,4 +1006,8 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: PgVectorScaleLoadingConfig, CaseLabel.Performance: PgVectorScalePerformanceConfig, }, + DB.MariaDB: { + CaseLabel.Load: MariaDBLoadingConfig, + CaseLabel.Performance: MariaDBPerformanceConfig, + }, } diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 75b427091..0574e3c29 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -70,6 +70,8 @@ class CaseConfigParamType(Enum): num_bits_per_dimension = "num_bits_per_dimension" query_search_list_size = "query_search_list_size" query_rescore = "query_rescore" + storage_engine = "storage_engine" + cache_size = "cache_size" class CustomizedCase(BaseModel):