diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..cc30f74dd --- /dev/null +++ b/.editorconfig @@ -0,0 +1,18 @@ +# EditorConfig is awesome: https://EditorConfig.org + +# top-most EditorConfig file +root = true + +[*] +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = false +insert_final_newline = false + +[Dockerfile*] +indent_style = space +indent_size = 4 + +[*.json] +indent_style = space +indent_size = 4 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 004524444..8d28d7eeb 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,18 @@ __MACOSX build/ venv/ .idea/ + +# vscode files +.vscode/* +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/settings.json +!.vscode/*.code-snippets +!.vscode/c_cpp_properties.json + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..fb8612a23 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,28 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Streamlit", + "type": "debugpy", + "request": "launch", + "module": "streamlit", + "args": [ + "run", + "vectordb_bench/frontend/vdb_benchmark.py", + "--logger.level", + "info", + "--theme.base", + "light", + "--theme.primaryColor", + "#3670F2", + "--theme.secondaryBackgroundColor", + "#F0F2F6", + ], + "subProcess": true, + "justMyCode": false + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..88f75752b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,10 @@ +{ + "[python]": { + "editor.formatOnSave": false, + // "editor.codeActionsOnSave": { + // "source.fixAll": "always", + // "source.organizeImports": "always" + // }, + "editor.defaultFormatter": "charliermarsh.ruff" + } +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 000000000..085a47bd1 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,29 @@ +{ + // See https://go.microsoft.com/fwlink/?LinkId=733558 + // for the documentation about the tasks.json format + "version": "2.0.0", + "tasks": [ + { + "label": "build vectordb bench", + "type": "shell", + "command": "python", + "args": [ + "-m", + "pip", + "install", + "-e", + ".[test]" + ], + "group": { + "kind": "build", + "isDefault": true + } + }, + { + "label": "run vectordb bench", + "type": "shell", + "command": "init_bench", + "problemMatcher": [] + } + ] +} \ No newline at end of file diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 3df11610b..4662ea53c 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -32,6 +32,7 @@ class DB(Enum): PgVectoRS = "PgVectoRS" Redis = "Redis" Chroma = "Chroma" + Hippo = "Hippo" @property @@ -76,6 +77,10 @@ def init_cls(self) -> Type[VectorDB]: if self == DB.Chroma: from .chroma.chroma import ChromaClient return ChromaClient + + if self == DB.Hippo: + from .hippo.hippo import Hippo + return Hippo @property def config_cls(self) -> Type[DBConfig]: @@ -120,6 +125,10 @@ def config_cls(self) -> Type[DBConfig]: from .chroma.config import ChromaConfig return 
ChromaConfig + if self == DB.Hippo: + from .hippo.config import HippoConfig + return HippoConfig + def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]: if self == DB.Milvus: from .milvus.config import _milvus_case_config @@ -149,6 +158,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon from .pgvecto_rs.config import _pgvecto_rs_case_config return _pgvecto_rs_case_config.get(index_type) + if self == DB.Hippo: + from .hippo.config import HippoIndexConfig + return HippoIndexConfig + # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/hippo/config.py b/vectordb_bench/backend/clients/hippo/config.py new file mode 100644 index 000000000..0c9318ca0 --- /dev/null +++ b/vectordb_bench/backend/clients/hippo/config.py @@ -0,0 +1,67 @@ +from pydantic import BaseModel, Field, SecretStr +from transwarp_hippo_api.hippo_type import IndexType +from transwarp_hippo_api.hippo_type import MetricType as HippoMetricType + +from ..api import DBCaseConfig, DBConfig, MetricType + + +class HippoConfig(DBConfig): + ip: SecretStr = "" + port: SecretStr = "18902" + username: SecretStr = "shiva" + password: SecretStr = "shiva" + number_of_shards: int = Field(default=1, ge=1) + number_of_replicas: int = Field(default=1, ge=1) + insert_batch_size: int = Field(default=100, ge=1) + + def to_dict(self) -> dict: + return { + "host_port": [ + f"{self.ip.get_secret_value()}:{self.port.get_secret_value()}" + ], + "username": self.username.get_secret_value(), + "pwd": self.password.get_secret_value(), + "number_of_shards": self.number_of_shards, + "number_of_replicas": self.number_of_replicas, + "insert_batch_size": self.insert_batch_size, + } + + +class HippoIndexConfig(BaseModel, DBCaseConfig): + index: IndexType = IndexType.HNSW # HNSW, FLAT, IVF_FLAT, IVF_SQ, IVF_PQ, ANNOY + metric_type: MetricType | None = None + M: int = 30 # [4,96] + ef_construction: int = 360 # [8, 512] + ef_search: int = 100 # [topk, 32768] + nlist: int = 1024 # [1,65536] + nprobe: int = 64 # [1, nlist] + m: int = 16 # divisible by dim + nbits: int = 8 # [1, 16] + k_factor: int = 100 # [10, 1000] + + def parse_metric(self) -> HippoMetricType: + if self.metric_type == MetricType.COSINE: + return HippoMetricType.COSINE + if self.metric_type == MetricType.IP: + return HippoMetricType.IP + if self.metric_type == MetricType.L2: + return HippoMetricType.L2 + return "" + + def index_param(self) -> dict: + return { + "M": self.M, + "ef_construction": self.ef_construction, + "ef_search": self.ef_search, + "nlist": self.nlist, + "nprobe": self.nprobe, + "m": self.m, + "nbits": self.nbits, + } + + def search_param(self) -> dict: + return { + "ef_search": self.ef_search, + "nprobe": self.nprobe, + "k_factor": self.k_factor, + } diff --git a/vectordb_bench/backend/clients/hippo/hippo.py b/vectordb_bench/backend/clients/hippo/hippo.py new file mode 100644 index 000000000..3f17bd302 --- /dev/null +++ b/vectordb_bench/backend/clients/hippo/hippo.py @@ -0,0 +1,212 @@ +import logging +from contextlib import contextmanager +from typing import Iterable + +import numpy as np +from transwarp_hippo_api.hippo_client import HippoClient, HippoField +from transwarp_hippo_api.hippo_type import HippoType + +from ..api import VectorDB +from .config import HippoIndexConfig + +log = logging.getLogger(__name__) + + +class Hippo(VectorDB): + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: HippoIndexConfig, + drop_old: bool = False, + **kwargs, + 
):
+        """Initialize wrapper around the Hippo vector database."""
+        self.name = "Hippo"
+        self.db_config = db_config
+        self.index_config = db_case_config
+
+        self.database_name = "default"
+        self.table_name = "vdbbench_table"
+        self.index_name = "vector_index"
+
+        self.vector_field_name = "vector"
+        self.int_field_name = "label"
+        self.pk_field_name = "pk"
+
+        self.insert_batch_size = db_config.get("insert_batch_size")
+        self.activated = False
+
+        # if `drop_old`, check whether the table exists and delete it
+        hc = HippoClient(
+            **{
+                k: db_config[k]
+                for k in ["host_port", "username", "pwd"]
+                if k in db_config
+            }
+        )
+        if drop_old:
+            try:
+                table_check = hc.check_table_exists(
+                    self.table_name, database_name=self.database_name
+                )
+                log.info(f"table exists: {table_check}")
+            except ValueError as e:
+                log.error("failed to check whether the table exists; skip", exc_info=e)
+                table_check = False
+
+            if table_check:
+                log.info(f"delete table: {self.table_name}")
+                hc.delete_table(self.table_name, database_name=self.database_name)
+                hc.delete_table_in_trash(
+                    self.table_name, database_name=self.database_name
+                )
+
+        # create table
+        fields = [
+            HippoField(self.pk_field_name, True, HippoType.INT64),
+            HippoField(self.int_field_name, False, HippoType.INT64),
+            HippoField(
+                self.vector_field_name,
+                False,
+                HippoType.FLOAT_VECTOR,
+                type_params={"dimension": dim},
+            ),
+        ]
+        log.info(f"create table: {self.table_name}")
+        hc.create_table(
+            name=self.table_name,
+            fields=fields,
+            database_name=self.database_name,
+            number_of_shards=db_config.get("number_of_shards"),
+            number_of_replicas=db_config.get("number_of_replicas"),
+        )
+
+        table = hc.get_table(self.table_name, database_name=self.database_name)
+        # create index
+        log.info("create index")
+        table.create_index(
+            field_name=self.vector_field_name,
+            index_name=self.index_name,
+            index_type=self.index_config.index,
+            metric_type=self.index_config.parse_metric(),
+            **self.index_config.index_param(),
+        )
+
+    def need_normalize_cosine(self) -> bool:
+        """Whether this database needs to normalize the dataset to support COSINE."""
+        return False
+
+    @contextmanager
+    def init(self):
+        """
+        Generate the connection.
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+        from transwarp_hippo_api.hippo_client import HippoClient
+
+        hc = HippoClient(
+            **{
+                k: self.db_config[k]
+                for k in ["host_port", "username", "pwd"]
+                if k in self.db_config
+            }
+        )
+        self.client = hc.get_table(self.table_name, database_name=self.database_name)
+
+        yield
+
+    def _activate_index(self):
+        if not self.activated:
+            try:
+                log.info("start activating index, please wait ...")
+                self.client.activate_index(
+                    self.index_name, wait_for_completion=True, timeout="25h"
+                )
+                log.info("index is activated.")
+            except Exception as e:
+                log.error("failed to activate index; skip", exc_info=e)
+
+            self.activated = True
+
+    def insert_embeddings(
+        self, embeddings: Iterable[list[float]], metadata: list[int], **kwargs
+    ):
+        assert self.client is not None
+        insert_count = 0
+        try:
+            for batch_start_offset in range(0, len(embeddings), self.insert_batch_size):
+                log.info("batch offset: %d", batch_start_offset)
+
+                # pk and label columns are both filled from metadata (the benchmark ids)
+                data = [
+                    list(
+                        metadata[
+                            batch_start_offset : batch_start_offset
+                            + self.insert_batch_size
+                        ]
+                    ),
+                    list(
+                        metadata[
+                            batch_start_offset : batch_start_offset
+                            + self.insert_batch_size
+                        ]
+                    ),
+                    [
+                        i.tolist() if isinstance(i, np.ndarray) else i
+                        for i in embeddings[
+                            batch_start_offset : batch_start_offset
+                            + self.insert_batch_size
+                        ]
+                    ],
+                ]
+
+                self.client.insert_rows(data)
+                insert_count += len(data[0])
+            # if kwargs.get("last_batch"):
+            #     self._activate_index()
+        except Exception as e:
+            log.error("hippo insert error", exc_info=e)
+            return (insert_count, e)
+
+        log.info("total insert: %d", insert_count)
+
+        return (insert_count, None)
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        # assert self.client is not None
+
+        dsl = f"{self.int_field_name} >= {filters['id']}" if filters else ""
+        output_fields = [self.int_field_name]
+        result = self.client.query(
+            self.vector_field_name,
+            [query],
+            output_fields,
+            k,
+            dsl=dsl,
+            **self.index_config.search_param(),
+        )
+
+        return result[0][self.int_field_name]
+
+    def optimize(self, **kwargs):
+        self._activate_index()
+
+        if kwargs.get("filters"):
+            log.info(f"create scalar index on field: {self.int_field_name}")
+            self.client.create_scalar_index(
+                field_names=[self.int_field_name],
+                index_name="idx_" + self.int_field_name,
+            )
+            log.info("scalar index created")
+
+    def ready_to_load(self):
+        return
diff --git a/vectordb_bench/backend/clients/hippo/test.py b/vectordb_bench/backend/clients/hippo/test.py
new file mode 100644
index 000000000..2d20d711b
--- /dev/null
+++ b/vectordb_bench/backend/clients/hippo/test.py
@@ -0,0 +1,93 @@
+from transwarp_hippo_api.hippo_client import HippoClient, HippoField
+from transwarp_hippo_api.hippo_type import HippoType, IndexType, MetricType
+import numpy as np
+
+ip = ""
+port = ""
+username = ""
+pwd = ""
+
+dim = 128
+n_train = 10000
+n_test = 100
+
+# connect
+hc = HippoClient([f"{ip}:{port}"], username=username, pwd=pwd)
+
+# create database
+database_name = "default"
+# db = hc.create_database(database_name)
+
+# create table
+table_name = "vdbbench_table"
+# table_check = hc.check_table_exists(table_name, database_name=database_name)
+# if table_check:
+#     hc.delete_table(table_name, database_name=database_name)
+#     hc.delete_table_in_trash(table_name, database_name=database_name)
+vector_field_name = "vector"
+int_field_name = "label"
+pk_field_name = "pk"
+fields = [
+    HippoField(pk_field_name, True, HippoType.INT64),
+    HippoField(int_field_name, False, HippoType.INT64),
+    HippoField(vector_field_name, False, HippoType.FLOAT_VECTOR,
+               type_params={"dimension": dim}),
+]
+client = hc.create_table(name=table_name, fields=fields,
+                         database_name=database_name,
+                         number_of_shards=1, number_of_replicas=1)
+
+
+# get table
+client = hc.get_table(table_name, database_name=database_name)
+
+
+# create index
+index_name = "vector_index"
+M = 30  # [4,96]
+ef_construction = 360  # [8, 512]
+ef_search = 100  # [topk, 32768]
+client.create_index(field_name=vector_field_name, index_name=index_name,
+                    index_type=IndexType.HNSW, metric_type=MetricType.L2,
+                    M=M, ef_construction=ef_construction, ef_search=ef_search)
+
+
+# # load?
+# index_loaded = client.load_index(index_name)
+
+# insert
+pk_data = np.arange(n_train)
+int_data = np.random.randint(0, 100, n_train)
+vector_data = np.random.rand(n_train, dim)
+batch_size = 100
+for offset in range(0, n_train, batch_size):
+    start = offset
+    end = offset + batch_size
+    print(f"insert {start}-{end}")
+    data = [
+        pk_data[start:end].tolist(),
+        int_data[start:end].tolist(),
+        vector_data[start:end].tolist(),
+    ]
+    client.insert_rows(data)
+
+# needs activation before search, similar to Milvus's load
+client.activate_index(index_name, wait_for_completion=True, timeout="25h")
+
+# ann search
+query_vectors = np.random.rand(n_test, dim)
+output_fields = [pk_field_name, int_field_name]
+k = 10
+dsl = f"{int_field_name} >= 90"
+result = client.query(vector_field_name, query_vectors.tolist(),
+                      output_fields, topk=k, dsl=dsl)
+print(result[0])
+
+result = client.query(vector_field_name, query_vectors.tolist(),
+                      output_fields, topk=100)
+print(result[0])
+
+# delete table
+hc.delete_table(table_name, database_name=database_name)
+hc.delete_table_in_trash(table_name, database_name=database_name)
+
+# # delete database
+# hc.delete_database(database_name)
diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py
index aeed0ec74..e11822f9a 100644
--- a/vectordb_bench/backend/runner/serial_runner.py
+++ b/vectordb_bench/backend/runner/serial_runner.py
@@ -47,6 +47,7 @@ def task(self) -> int:
             log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}")
 
             last_batch = self.dataset.data.size - count == len(all_metadata)
+            log.info(f"last_batch: {last_batch} data size: {self.dataset.data.size} count: {count}")
             insert_count, error = self.db.insert_embeddings(
                 embeddings=all_embeddings,
                 metadata=all_metadata,
@@ -114,7 +115,7 @@ def _insert_all_batches(self) -> int:
                 psutil.Process(pid).kill()
                 raise PerformanceTimeoutError(msg) from e
             except Exception as e:
-                log.warning(f"VectorDB load dataset error: {e}")
+                log.error("VectorDB load dataset error", exc_info=e)
                 raise e from e
             else:
                 return count
@@ -169,7 +170,7 @@ def __init__(
         self.test_data = test_data
         self.ground_truth = ground_truth
 
-    def search(self, args: tuple[list, pd.DataFrame]):
+    def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
         log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
         with self.db.init():
             test_data, ground_truth = args
@@ -213,14 +214,14 @@ def search(self, args: tuple[list, pd.DataFrame]):
             f"avg_latency={avg_latency}, "
             f"p99={p99}"
         )
-        return (avg_recall, p99)
+        return (avg_recall, p99, avg_latency)
 
-    def _run_in_subprocess(self) -> tuple[float, float]:
+    def _run_in_subprocess(self) -> tuple[float, float, float]:
         with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
             future = executor.submit(self.search, (self.test_data, self.ground_truth))
             result = future.result()
             return result
 
-    def run(self) -> tuple[float, float]:
+    def run(self) -> tuple[float, float, float]:
         return self._run_in_subprocess()
diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py
index 80c5ac1df..b78a6bab3 100644
--- a/vectordb_bench/backend/task_runner.py
+++ b/vectordb_bench/backend/task_runner.py
@@ -58,7 +58,7 @@ def __eq__(self, obj):
                 self.config.db == obj.config.db and \
                 self.config.db_case_config == obj.config.db_case_config and \
                 self.ca.dataset == obj.ca.dataset
-        return False
+        return False
 
     def display(self) -> dict:
         c_dict = 
self.ca.dict(include={'label':True, 'filters': True,'dataset':{'data': {'name': True, 'size': True, 'dim': True, 'metric_type': True, 'label': True}} }) @@ -140,7 +140,7 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric: ) self._init_search_runner() - m.recall, m.serial_latency_p99 = self._serial_search() + m.recall, m.serial_latency_p99, m.serial_latency_avg = self._serial_search() m.qps = self._conc_search() except Exception as e: log.warning(f"Failed to run performance case, reason = {e}") @@ -161,7 +161,7 @@ def _load_train_data(self): finally: runner = None - def _serial_search(self) -> tuple[float, float]: + def _serial_search(self) -> tuple[float, float, float]: """Performance serial tests, search the entire test data once, calculate the recall, serial_latency_p99 @@ -193,7 +193,7 @@ def _conc_search(self): @utils.time_it def _task(self) -> None: with self.db.init(): - self.db.optimize() + self.db.optimize(filters=self.ca.filters) def _optimize(self) -> float: with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: diff --git a/vectordb_bench/frontend/components/run_test/dbSelector.py b/vectordb_bench/frontend/components/run_test/dbSelector.py index 61db843f3..40c5ee5f9 100644 --- a/vectordb_bench/frontend/components/run_test/dbSelector.py +++ b/vectordb_bench/frontend/components/run_test/dbSelector.py @@ -30,7 +30,7 @@ def dbSelector(st): for i, db in enumerate(DB_LIST): column = dbContainerColumns[i % DB_SELECTOR_COLUMNS] dbIsActived[db] = column.checkbox(db.name) - column.image(DB_TO_ICON.get(db, "")) + column.image(DB_TO_ICON.get(db, ""), width=100) activedDbList = [db for db in DB_LIST if dbIsActived[db]] return activedDbList diff --git a/vectordb_bench/frontend/const/dbCaseConfigs.py b/vectordb_bench/frontend/const/dbCaseConfigs.py index fad5f362d..ad6a3b73f 100644 --- a/vectordb_bench/frontend/const/dbCaseConfigs.py +++ b/vectordb_bench/frontend/const/dbCaseConfigs.py @@ -1,10 +1,12 @@ -from enum import IntEnum import typing +from enum import IntEnum + from pydantic import BaseModel +from transwarp_hippo_api.hippo_type import IndexType as HippoIndexType + from vectordb_bench.backend.cases import CaseLabel, CaseType from vectordb_bench.backend.clients import DB from vectordb_bench.backend.clients.api import IndexType - from vectordb_bench.models import CaseConfigParamType MAX_STREAMLIT_INT = (1 << 53) - 1 @@ -419,6 +421,156 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_IndexType_Hippo = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputType=InputType.Option, + inputConfig={ + "options": [ + HippoIndexType.HNSW.value, + HippoIndexType.FLAT.value, + HippoIndexType.IVF_FLAT.value, + HippoIndexType.IVF_SQ.value, + HippoIndexType.IVF_PQ.value, + HippoIndexType.ANNOY.value, + ], + }, +) + +CaseConfigParamInput_M_Hippo = CaseConfigInput( + label=CaseConfigParamType.M, + inputType=InputType.Number, + inputConfig={ + "min": 4, + "max": 64, + "value": 30, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == HippoIndexType.HNSW.value, +) + +CaseConfigParamInput_EFConstruction_Hippo = CaseConfigInput( + label=CaseConfigParamType.ef_construction, + inputType=InputType.Number, + inputConfig={ + "min": 8, + "max": 512, + "value": 360, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == HippoIndexType.HNSW.value, +) + +CaseConfigParamInput_EFSearch_Hippo = CaseConfigInput( + label=CaseConfigParamType.ef_search, + inputType=InputType.Number, + inputConfig={ + "min": 
100, + "max": MAX_STREAMLIT_INT, + "value": 100, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == HippoIndexType.HNSW.value, +) + +CaseConfigParamInput_Nlist_Hippo = CaseConfigInput( + label=CaseConfigParamType.Nlist, + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 65536, + "value": 1024, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + HippoIndexType.IVF_FLAT.value, + HippoIndexType.IVF_SQ.value, + HippoIndexType.IVF_PQ.value, + # TODO: add ivf_pq_fs + ], +) + +CaseConfigParamInput_Nprobe_Hippo = CaseConfigInput( + label=CaseConfigParamType.Nprobe, + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 65536, + "value": 64, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + HippoIndexType.IVF_FLAT.value, + HippoIndexType.IVF_SQ.value, + HippoIndexType.IVF_PQ.value, + # TODO: add ivf_pq_fs + ], +) + +CaseConfigParamInput_m_Hippo = CaseConfigInput( + label=CaseConfigParamType.m, + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 1024, + "value": 16, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + HippoIndexType.IVF_PQ.value, + # TODO: add ivf_pq_fs + ], +) + +CaseConfigParamInput_nbits_Hippo = CaseConfigInput( + label=CaseConfigParamType.nbits, + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 16, + "value": 8, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + HippoIndexType.IVF_PQ.value, + ], +) + +CaseConfigParamInput_k_factor_Hippo = CaseConfigInput( + label=CaseConfigParamType.k_factor, + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 1000, + "value": 100, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + HippoIndexType.IVF_PQ.value, + HippoIndexType.IVF_SQ.value, + # TODO: add ivf_pq_fs + ], +) + +CaseConfigParamInput_index_slow_refine_Hippo = CaseConfigInput( + label=CaseConfigParamType.index_slow_refine, + inputType=InputType.Option, + inputConfig={"options": [False, True]}, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + HippoIndexType.IVF_PQ.value, + HippoIndexType.IVF_SQ.value, + # TODO: add ivf_pq_fs + ], +) + +CaseConfigParamInput_sq_type_Hippo = CaseConfigInput( + label=CaseConfigParamType.sq_type, + inputType=InputType.Text, + inputConfig={"value": ""}, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == HippoIndexType.IVF_SQ.value, +) + MilvusLoadConfig = [ CaseConfigParamInput_IndexType, CaseConfigParamInput_M, @@ -496,6 +648,29 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_ZillizLevel, ] +HippoLoadConfig = [ + CaseConfigParamInput_IndexType_Hippo, + CaseConfigParamInput_M_Hippo, + CaseConfigParamInput_EFConstruction_Hippo, + CaseConfigParamInput_EFSearch_Hippo, + CaseConfigParamInput_Nlist_Hippo, + CaseConfigParamInput_m_Hippo, + CaseConfigParamInput_nbits_Hippo, +] +HippoPerformanceConfig = [ + CaseConfigParamInput_IndexType_Hippo, + CaseConfigParamInput_M_Hippo, + CaseConfigParamInput_EFConstruction_Hippo, + CaseConfigParamInput_EFSearch_Hippo, + CaseConfigParamInput_Nlist_Hippo, + CaseConfigParamInput_Nprobe_Hippo, + CaseConfigParamInput_m_Hippo, + CaseConfigParamInput_nbits_Hippo, + CaseConfigParamInput_k_factor_Hippo, + CaseConfigParamInput_index_slow_refine_Hippo, + CaseConfigParamInput_sq_type_Hippo, +] + CASE_CONFIG_MAP = { DB.Milvus: { CaseLabel.Load: 
MilvusLoadConfig, @@ -520,4 +695,8 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: PgVectoRSLoadingConfig, CaseLabel.Performance: PgVectoRSPerformanceConfig, }, + DB.Hippo: { + CaseLabel.Load: HippoLoadConfig, + CaseLabel.Performance: HippoPerformanceConfig, + }, } diff --git a/vectordb_bench/frontend/const/styles.py b/vectordb_bench/frontend/const/styles.py index 52d1017a9..ad3e4d9c0 100644 --- a/vectordb_bench/frontend/const/styles.py +++ b/vectordb_bench/frontend/const/styles.py @@ -46,6 +46,7 @@ def getPatternShape(i): DB.PgVectoRS: "https://assets.zilliz.com/PG_Vector_d464f2ef5f.png", DB.Redis: "https://assets.zilliz.com/Redis_Cloud_74b8bfef39.png", DB.Chroma: "https://assets.zilliz.com/chroma_ceb3f06ed7.png", + DB.Hippo: "https://assets.zilliz.com/hippo_3ce85bc90f.png", } # RedisCloud color: #0D6EFD @@ -59,4 +60,5 @@ def getPatternShape(i): DB.WeaviateCloud.value: "#20C997", DB.PgVector.value: "#4C779A", DB.Redis.value: "#0D6EFD", + DB.Hippo.value: "#333", } diff --git a/vectordb_bench/metric.py b/vectordb_bench/metric.py index a2b6d6ff0..be90f3adb 100644 --- a/vectordb_bench/metric.py +++ b/vectordb_bench/metric.py @@ -18,12 +18,14 @@ class Metric: load_duration: float = 0.0 # duration to load all dataset into DB qps: float = 0.0 serial_latency_p99: float = 0.0 + serial_latency_avg: float = 0.0 recall: float = 0.0 QURIES_PER_DOLLAR_METRIC = "QP$ (Quries per Dollar)" LOAD_DURATION_METRIC = "load_duration" SERIAL_LATENCY_P99_METRIC = "serial_latency_p99" +SERIAL_LATENCY_AVG_METRIC = "serial_latency_avg" MAX_LOAD_COUNT_METRIC = "max_load_count" QPS_METRIC = "qps" RECALL_METRIC = "recall" @@ -31,6 +33,7 @@ class Metric: metricUnitMap = { LOAD_DURATION_METRIC: "s", SERIAL_LATENCY_P99_METRIC: "ms", + SERIAL_LATENCY_AVG_METRIC: "ms", MAX_LOAD_COUNT_METRIC: "K", QURIES_PER_DOLLAR_METRIC: "K", } @@ -38,6 +41,7 @@ class Metric: lowerIsBetterMetricList = [ LOAD_DURATION_METRIC, SERIAL_LATENCY_P99_METRIC, + SERIAL_LATENCY_AVG_METRIC, ] metricOrder = [ @@ -45,6 +49,7 @@ class Metric: RECALL_METRIC, LOAD_DURATION_METRIC, SERIAL_LATENCY_P99_METRIC, + SERIAL_LATENCY_AVG_METRIC, MAX_LOAD_COUNT_METRIC, ] diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 3c2a5b9aa..da7da9ca6 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -1,23 +1,22 @@ import logging import pathlib from datetime import date -from typing import Self from enum import Enum +from typing import Self import ujson +from . import config +from .backend.cases import CaseType from .backend.clients import ( DB, - DBConfig, DBCaseConfig, + DBConfig, IndexType, ) -from .backend.cases import CaseType from .base import BaseModel -from . import config from .metric import Metric - log = logging.getLogger(__name__) @@ -60,6 +59,11 @@ class CaseConfigParamType(Enum): cache_dataset_on_device = "cache_dataset_on_device" refine_ratio = "refine_ratio" level = "level" + ef_construction = "ef_construction" + ef_search = "ef_search" + k_factor = "k_factor" + index_slow_refine = "index_slow_refine" + sq_type = "sq_type" class CustomizedCase(BaseModel):
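
Reviewer note: for a quick smoke test of the new client outside the Streamlit UI, a minimal sketch along the lines of `test.py` is below. It uses only classes added in this PR; the host/port values and the 128-dim toy vectors are placeholders, and a reachable Hippo deployment is assumed.

```python
from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.api import MetricType

# Resolve the classes registered in clients/__init__.py (imports are lazy).
Hippo = DB.Hippo.init_cls                      # VectorDB wrapper
HippoConfig = DB.Hippo.config_cls              # connection settings
HippoIndexConfig = DB.Hippo.case_config_cls()  # index/search settings

db_config = HippoConfig(ip="127.0.0.1", port="18902").to_dict()  # placeholder host
case_config = HippoIndexConfig(metric_type=MetricType.L2)        # HNSW by default

# __init__ connects, (re)creates the table, and creates the vector index.
hippo = Hippo(dim=128, db_config=db_config, db_case_config=case_config, drop_old=True)

with hippo.init():
    count, err = hippo.insert_embeddings(
        embeddings=[[0.1] * 128, [0.2] * 128],
        metadata=[0, 1],
    )
    assert err is None and count == 2
    hippo.optimize()  # activates the index, like test.py's activate_index()
    print(hippo.search_embedding(query=[0.1] * 128, k=2))
```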
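On the metrics side, the change is plumbing: `SerialSearchRunner.search` already computed `avg_latency` alongside `p99`, and this PR returns it as a third tuple element, which `_run_perf_case` stores in `m.serial_latency_avg` and `metric.py` renders in ms. A sketch of the aggregation with hypothetical latency values (the runner's exact formula is not shown in this diff, but it presumably reduces a per-query latency list like this):

```python
import numpy as np

latencies = [0.012, 0.015, 0.011, 0.042]  # hypothetical per-query latencies, seconds
avg_latency = round(float(np.mean(latencies)), 4)
p99 = round(float(np.percentile(latencies, 99)), 4)
# The serial runner now returns (avg_recall, p99, avg_latency); task_runner unpacks:
# m.recall, m.serial_latency_p99, m.serial_latency_avg = self._serial_search()
```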