From 02b43b04a29917f00c4f9941f71693acbd228e68 Mon Sep 17 00:00:00 2001
From: Hugo Wen
Date: Fri, 18 Oct 2024 10:38:13 -0700
Subject: [PATCH] Support MariaDB database

MariaDB introduced vector support in version 11.7, enabling MariaDB Server
to function as a relational vector database.
https://mariadb.com/kb/en/vectors/

This patch adds a MariaDB client, verified against MariaDB Server 11.7.1:

- Support MariaDB vector search with the HNSW algorithm.
- Support index and search parameters:
  - storage_engine: InnoDB or MyISAM
  - M: the M parameter in MHNSW vector indexing
  - ef_search: minimal number of result candidates to look for in the
    vector index for ORDER BY ... LIMIT N queries
  - max_cache_size: upper limit for one MHNSW vector index cache
- Add the `vectordbbench mariadbhnsw` CLI command.
---
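For reference, the SQL this client ends up driving against MariaDB looks
roughly as follows (dimension, M, and LIMIT are example values; the exact
statements are assembled in mariadb.py below):

    CREATE TABLE vectordbbench.vec_collection (
        id INT PRIMARY KEY,
        v VECTOR(768) NOT NULL
    ) ENGINE=InnoDB;

    ALTER TABLE vectordbbench.vec_collection
        ADD VECTOR KEY v(v) DISTANCE=euclidean M=6;

    SET GLOBAL mhnsw_max_cache_size = 17179869184;
    SET mhnsw_ef_search = 20;

    -- the query vector is bound as a packed float32 byte string
    SELECT id FROM vectordbbench.vec_collection
    ORDER BY vec_distance_euclidean(v, ?) LIMIT 100;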
 pyproject.toml                                |   2 +
 vectordb_bench/backend/clients/__init__.py    |  13 ++
 vectordb_bench/backend/clients/mariadb/cli.py | 107 +++++++++
 .../backend/clients/mariadb/config.py         |  71 ++++++
 .../backend/clients/mariadb/mariadb.py        | 210 ++++++++++++++++++
 vectordb_bench/cli/vectordbbench.py           |   2 +
 .../frontend/config/dbCaseConfigs.py          |  77 +++++++
 vectordb_bench/models.py                      |   2 +
 8 files changed, 484 insertions(+)
 create mode 100644 vectordb_bench/backend/clients/mariadb/cli.py
 create mode 100644 vectordb_bench/backend/clients/mariadb/config.py
 create mode 100644 vectordb_bench/backend/clients/mariadb/mariadb.py

diff --git a/pyproject.toml b/pyproject.toml
index 6ad8e23e..89b864b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,7 @@ all = [
     "opensearch-dsl",
     "opensearch-py",
     "memorydb",
+    "mariadb",
 ]
 
 qdrant = [ "qdrant-client" ]
@@ -81,6 +82,7 @@ redis = [ "redis" ]
 memorydb = [ "memorydb" ]
 chromadb = [ "chromadb" ]
 opensearch = [ "opensearch-py" ]
+mariadb = [ "mariadb" ]
 
 [project.urls]
 "repository" = "https://github.com/zilliztech/VectorDBBench"
diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py
index e1b66a81..978f31cd 100644
--- a/vectordb_bench/backend/clients/__init__.py
+++ b/vectordb_bench/backend/clients/__init__.py
@@ -38,6 +38,7 @@ class DB(Enum):
     Chroma = "Chroma"
     AWSOpenSearch = "OpenSearch"
     AliyunElasticsearch = "AliyunElasticsearch"
+    MariaDB = "MariaDB"
     Test = "test"
     AliyunOpenSearch = "AliyunOpenSearch"
 
@@ -113,6 +114,10 @@ def init_cls(self) -> Type[VectorDB]:
             from .aliyun_opensearch.aliyun_opensearch import AliyunOpenSearch
             return AliyunOpenSearch
 
+        if self == DB.MariaDB:
+            from .mariadb.mariadb import MariaDB
+            return MariaDB
+
     @property
     def config_cls(self) -> Type[DBConfig]:
         """Import while in use"""
@@ -184,6 +189,10 @@ def config_cls(self) -> Type[DBConfig]:
             from .aliyun_opensearch.config import AliyunOpenSearchConfig
             return AliyunOpenSearchConfig
 
+        if self == DB.MariaDB:
+            from .mariadb.config import MariaDBConfig
+            return MariaDBConfig
+
     def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
         if self == DB.Milvus:
             from .milvus.config import _milvus_case_config
@@ -237,6 +246,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
             from .aliyun_opensearch.config import AliyunOpenSearchIndexConfig
             return AliyunOpenSearchIndexConfig
 
+        if self == DB.MariaDB:
+            from .mariadb.config import _mariadb_case_config
+            return _mariadb_case_config.get(index_type)
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
 
diff --git a/vectordb_bench/backend/clients/mariadb/cli.py b/vectordb_bench/backend/clients/mariadb/cli.py
new file mode 100644
index 00000000..c5439f37
--- /dev/null
+++ b/vectordb_bench/backend/clients/mariadb/cli.py
@@ -0,0 +1,107 @@
+from typing import Annotated, Optional, Unpack
+
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor1,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from vectordb_bench.backend.clients import DB
+
+
+class MariaDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--username",
+                          type=str,
+                          help="Username",
+                          required=True,
+                          ),
+    ]
+    password: Annotated[
+        str, click.option("--password",
+                          type=str,
+                          help="Password",
+                          required=True,
+                          ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host",
+                          type=str,
+                          help="DB host",
+                          default="127.0.0.1",
+                          ),
+    ]
+
+    port: Annotated[
+        int, click.option("--port",
+                          type=int,
+                          default=3306,
+                          help="DB port",
+                          ),
+    ]
+
+    storage_engine: Annotated[
+        str, click.option("--storage-engine",
+                          type=click.Choice(["InnoDB", "MyISAM"]),
+                          help="DB storage engine",
+                          required=True,
+                          ),
+    ]
+
+
+class MariaDBHNSWTypedDict(MariaDBTypedDict):
+    m: Annotated[
+        Optional[int], click.option("--m",
+                                    type=int,
+                                    help="M parameter in MHNSW vector indexing",
+                                    required=False,
+                                    ),
+    ]
+
+    ef_search: Annotated[
+        Optional[int], click.option("--ef-search",
+                                    type=int,
+                                    help="MariaDB system variable mhnsw_ef_search",
+                                    required=False,
+                                    ),
+    ]
+
+    max_cache_size: Annotated[
+        Optional[int], click.option("--max-cache-size",
+                                    type=int,
+                                    help="MariaDB system variable mhnsw_max_cache_size",
+                                    required=False,
+                                    ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
+def MariaDBHNSW(
+    **parameters: Unpack[MariaDBHNSWTypedDict],
+):
+    from .config import MariaDBConfig, MariaDBHNSWConfig
+
+    run(
+        db=DB.MariaDB,
+        db_config=MariaDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=MariaDBHNSWConfig(
+            M=parameters["m"],
+            ef_search=parameters["ef_search"],
+            storage_engine=parameters["storage_engine"],
+            max_cache_size=parameters["max_cache_size"],
+        ),
+        **parameters,
+    )
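A typical invocation of the new command could look like the sketch below; the
MariaDB-specific options are the ones declared above, while the remaining
case/dataset options come from the shared CommonTypedDict and are not shown in
this patch:

    vectordbbench mariadbhnsw \
        --username root --password '<password>' \
        --host 127.0.0.1 --port 3306 \
        --storage-engine InnoDB \
        --m 16 --ef-search 40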
diff --git a/vectordb_bench/backend/clients/mariadb/config.py b/vectordb_bench/backend/clients/mariadb/config.py
new file mode 100644
index 00000000..c7b2cd5f
--- /dev/null
+++ b/vectordb_bench/backend/clients/mariadb/config.py
@@ -0,0 +1,71 @@
+from pydantic import SecretStr, BaseModel
+from typing import TypedDict
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+
+class MariaDBConfigDict(TypedDict):
+    """These keys are passed directly as kwargs to the mariadb connection,
+    so the names must exactly match the mariadb connector API."""
+
+    user: str
+    password: str
+    host: str
+    port: int
+
+
+class MariaDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 3306
+
+    def to_dict(self) -> MariaDBConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class MariaDBIndexConfig(BaseModel):
+    """Base config for MariaDB"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        elif self.metric_type == MetricType.COSINE:
+            return "cosine"
+        else:
+            raise ValueError(f"Metric type {self.metric_type} is not supported!")
+
+
+class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
+    M: int | None
+    ef_search: int | None
+    index: IndexType = IndexType.HNSW
+    storage_engine: str = "InnoDB"
+    max_cache_size: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "storage_engine": self.storage_engine,
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "M": self.M,
+            "max_cache_size": self.max_cache_size,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "ef_search": self.ef_search,
+        }
+
+
+_mariadb_case_config = {
+    IndexType.HNSW: MariaDBHNSWConfig,
+}
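As an illustration (not part of the patch) of what the case config feeds into
the client, a config built from the classes above produces parameter dicts
like these; the exact values depend on the chosen metric and CLI flags:

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.mariadb.config import MariaDBHNSWConfig

    cfg = MariaDBHNSWConfig(metric_type=MetricType.L2, M=16,
                            ef_search=40, max_cache_size=None)
    # index_param()  -> {"storage_engine": "InnoDB", "metric_type": "euclidean",
    #                    "index_type": "HNSW", "M": 16, "max_cache_size": None}
    # search_param() -> {"metric_type": "euclidean", "ef_search": 40}
    print(cfg.index_param(), cfg.search_param())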
diff --git a/vectordb_bench/backend/clients/mariadb/mariadb.py b/vectordb_bench/backend/clients/mariadb/mariadb.py
new file mode 100644
index 00000000..52c7d7ec
--- /dev/null
+++ b/vectordb_bench/backend/clients/mariadb/mariadb.py
@@ -0,0 +1,210 @@
+import logging
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple
+
+import mariadb
+import numpy as np
+
+from ..api import VectorDB
+from .config import MariaDBConfigDict, MariaDBIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class MariaDB(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: MariaDBConfigDict,
+        db_case_config: MariaDBIndexConfig,
+        collection_name: str = "vec_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "MariaDB"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.db_name = "vectordbbench"
+        self.table_name = collection_name
+        self.dim = dim
+
+        # construct basic units
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        if drop_old:
+            self._drop_db()
+            self._create_db_table(dim)
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]:
+        conn = mariadb.connect(**kwargs)
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    def _drop_db(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop db : {self.db_name}")
+
+        # flush tables before dropping database to avoid some locking issue
+        self.cursor.execute("FLUSH TABLES")
+        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
+        self.cursor.execute("COMMIT")
+        self.cursor.execute("FLUSH TABLES")
+
+    def _create_db_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            log.info(f"{self.name} client create database : {self.db_name}")
+            self.cursor.execute(f"CREATE DATABASE {self.db_name}")
+
+            log.info(f"{self.name} client create table : {self.table_name}")
+            self.cursor.execute(f"USE {self.db_name}")
+
+            self.cursor.execute(f"""
+                CREATE TABLE {self.table_name} (
+                    id INT PRIMARY KEY,
+                    v VECTOR({self.dim}) NOT NULL
+                ) ENGINE={index_param["storage_engine"]}
+            """)
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(f"Failed to create table: {self.table_name} error: {e}")
+            raise e from None
+
+    @contextmanager
+    def init(self) -> None:
+        """Create and destroy connections to the database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        index_param = self.case_config.index_param()
+        search_param = self.case_config.search_param()
+
+        # maximize allowed packet size
+        self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")
+
+        if index_param["index_type"] == "HNSW":
+            if index_param["max_cache_size"] is not None:
+                self.cursor.execute(
+                    f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}"
+                )
+            if search_param["ef_search"] is not None:
+                self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
+            self.cursor.execute("COMMIT")
+
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
+        self.select_sql = (
+            f"SELECT id FROM {self.db_name}.{self.table_name} "
+            f"ORDER BY vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def ready_to_load(self) -> bool:
+        pass
+
+    def optimize(self) -> None:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            index_options = f"DISTANCE={index_param['metric_type']}"
+            if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
+                index_options += f" M={index_param['M']}"
+
+            self.cursor.execute(f"""
+                ALTER TABLE {self.db_name}.{self.table_name}
+                    ADD VECTOR KEY v(v) {index_options}
+            """)
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(f"Failed to create index: {self.table_name} error: {e}")
+            raise e from None
+
+    @staticmethod
+    def vector_to_hex(v):
+        return np.array(v, "float32").tobytes()
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Insert embeddings into the database.
+        Should call self.init() first.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            batch_data = []
+            for i, row in enumerate(metadata_arr):
+                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])))
+
+            self.cursor.executemany(self.insert_sql, batch_data)
+            self.cursor.execute("COMMIT")
+            self.cursor.execute("FLUSH TABLES")
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
+
+        return [id for id, in self.cursor.fetchall()]
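The client binds VECTOR values as packed float32 bytes (see vector_to_hex
above). A minimal standalone sketch of the same flow with the mariadb
connector, using placeholder credentials and a toy 4-dimensional table, could
look like this:

    import mariadb
    import numpy as np

    conn = mariadb.connect(host="127.0.0.1", port=3306, user="root", password="secret")
    cur = conn.cursor()
    cur.execute("CREATE DATABASE IF NOT EXISTS vectordbbench")
    cur.execute(
        "CREATE TABLE IF NOT EXISTS vectordbbench.demo ("
        "  id INT PRIMARY KEY,"
        "  v VECTOR(4) NOT NULL"
        ") ENGINE=InnoDB"
    )
    # same DISTANCE/M options that optimize() passes to ADD VECTOR KEY
    cur.execute("ALTER TABLE vectordbbench.demo ADD VECTOR KEY v(v) DISTANCE=euclidean M=6")

    vec = np.array([0.1, 0.2, 0.3, 0.4], dtype="float32").tobytes()
    cur.execute("INSERT INTO vectordbbench.demo (id, v) VALUES (%s, %s)", (1, vec))
    conn.commit()

    cur.execute(
        "SELECT id FROM vectordbbench.demo "
        "ORDER BY vec_distance_euclidean(v, %s) LIMIT 10",
        (vec,),
    )
    print(cur.fetchall())
    conn.close()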
+ """ + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + try: + metadata_arr = np.array(metadata) + embeddings_arr = np.array(embeddings) + + batch_data = [] + for i, row in enumerate(metadata_arr): + batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i]))); + + self.cursor.executemany(self.insert_sql, batch_data) + self.cursor.execute("COMMIT") + self.cursor.execute("FLUSH TABLES") + + return len(metadata), None + except Exception as e: + log.warning( + f"Failed to insert data into Vector table ({self.table_name}), error: {e}" + ) + return 0, e + + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> (list[int]): + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + search_param = self.case_config.search_param() + + self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k)) + + return [id for id, in self.cursor.fetchall()] + diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index f9ad69ce..c54aa5b1 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -10,6 +10,7 @@ from ..backend.clients.milvus.cli import MilvusAutoIndex from ..backend.clients.aws_opensearch.cli import AWSOpenSearch from ..backend.clients.alloydb.cli import AlloyDBScaNN +from ..backend.clients.mariadb.cli import MariaDBHNSW from .cli import cli @@ -26,6 +27,7 @@ cli.add_command(PgVectorScaleDiskAnn) cli.add_command(PgDiskAnn) cli.add_command(AlloyDBScaNN) +cli.add_command(MariaDBHNSW) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 7f076a2d..0ca28f7a 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -1079,6 +1079,64 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_IndexType_MariaDB = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputHelp="Select Index Type", + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.HNSW.value, + ], + }, +) + +CaseConfigParamInput_StorageEngine_MariaDB = CaseConfigInput( + label=CaseConfigParamType.storage_engine, + inputHelp="Select Storage Engine", + inputType=InputType.Option, + inputConfig={ + "options": ["InnoDB", "MyISAM"], + }, +) + +CaseConfigParamInput_M_MariaDB = CaseConfigInput( + label=CaseConfigParamType.M, + inputHelp="M parameter in MHNSW vector indexing", + inputType=InputType.Number, + inputConfig={ + "min": 3, + "max": 200, + "value": 6, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + +CaseConfigParamInput_EFSearch_MariaDB = CaseConfigInput( + label=CaseConfigParamType.ef_search, + inputHelp="mhnsw_ef_search", + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 10000, + "value": 20, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + +CaseConfigParamInput_CacheSize_MariaDB = CaseConfigInput( + label=CaseConfigParamType.max_cache_size, + inputHelp="mhnsw_max_cache_size", + inputType=InputType.Number, + inputConfig={ + "min": 1048576, + "max": (1 << 53) - 1, + "value": 16 * 1024 ** 3, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == 
diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py
index 7f076a2d..0ca28f7a 100644
--- a/vectordb_bench/frontend/config/dbCaseConfigs.py
+++ b/vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -1079,6 +1079,64 @@ class CaseConfigInput(BaseModel):
     },
 )
 
+CaseConfigParamInput_IndexType_MariaDB = CaseConfigInput(
+    label=CaseConfigParamType.IndexType,
+    inputHelp="Select Index Type",
+    inputType=InputType.Option,
+    inputConfig={
+        "options": [
+            IndexType.HNSW.value,
+        ],
+    },
+)
+
+CaseConfigParamInput_StorageEngine_MariaDB = CaseConfigInput(
+    label=CaseConfigParamType.storage_engine,
+    inputHelp="Select Storage Engine",
+    inputType=InputType.Option,
+    inputConfig={
+        "options": ["InnoDB", "MyISAM"],
+    },
+)
+
+CaseConfigParamInput_M_MariaDB = CaseConfigInput(
+    label=CaseConfigParamType.M,
+    inputHelp="M parameter in MHNSW vector indexing",
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 3,
+        "max": 200,
+        "value": 6,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
+    == IndexType.HNSW.value,
+)
+
+CaseConfigParamInput_EFSearch_MariaDB = CaseConfigInput(
+    label=CaseConfigParamType.ef_search,
+    inputHelp="mhnsw_ef_search",
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 1,
+        "max": 10000,
+        "value": 20,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
+    == IndexType.HNSW.value,
+)
+
+CaseConfigParamInput_CacheSize_MariaDB = CaseConfigInput(
+    label=CaseConfigParamType.max_cache_size,
+    inputHelp="mhnsw_max_cache_size",
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 1048576,
+        "max": (1 << 53) - 1,
+        "value": 16 * 1024 ** 3,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
+    == IndexType.HNSW.value,
+)
+
 MilvusLoadConfig = [
     CaseConfigParamInput_IndexType,
@@ -1257,6 +1315,21 @@ class CaseConfigInput(BaseModel):
     CaseConfigParamInput_NumCandidates_AliES,
 ]
 
+MariaDBLoadingConfig = [
+    CaseConfigParamInput_IndexType_MariaDB,
+    CaseConfigParamInput_StorageEngine_MariaDB,
+    CaseConfigParamInput_M_MariaDB,
+    CaseConfigParamInput_CacheSize_MariaDB,
+]
+
+MariaDBPerformanceConfig = [
+    CaseConfigParamInput_IndexType_MariaDB,
+    CaseConfigParamInput_StorageEngine_MariaDB,
+    CaseConfigParamInput_M_MariaDB,
+    CaseConfigParamInput_CacheSize_MariaDB,
+    CaseConfigParamInput_EFSearch_MariaDB,
+]
+
 CASE_CONFIG_MAP = {
     DB.Milvus: {
         CaseLabel.Load: MilvusLoadConfig,
@@ -1305,4 +1378,8 @@ class CaseConfigInput(BaseModel):
         CaseLabel.Load: AliyunOpensearchLoadingConfig,
         CaseLabel.Performance: AliyunOpenSearchPerformanceConfig,
     },
+    DB.MariaDB: {
+        CaseLabel.Load: MariaDBLoadingConfig,
+        CaseLabel.Performance: MariaDBPerformanceConfig,
+    },
 }
diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py
index 648fb172..52c1e8e5 100644
--- a/vectordb_bench/models.py
+++ b/vectordb_bench/models.py
@@ -85,6 +85,8 @@ class CaseConfigParamType(Enum):
     preReorderingNumNeigbors = "pre_reordering_num_neighbors"
     numSearchThreads = "num_search_threads"
     maxNumPrefetchDatasets = "max_num_prefetch_datasets"
+    storage_engine = "storage_engine"
+    max_cache_size = "max_cache_size"
 
 
 class CustomizedCase(BaseModel):