From d8bd9611cf926e6842713afe0f6db30c29d2bfb6 Mon Sep 17 00:00:00 2001 From: Sheharyar Ahmad Date: Wed, 2 Oct 2024 23:23:32 +0500 Subject: [PATCH] Add cli support for running benchmark with custom dataset in pgvector hnsw and ivvfflat. --- .../backend/clients/pgvector/cli.py | 3 + vectordb_bench/cli/cli.py | 137 ++++++++++++++++++ 2 files changed, 140 insertions(+) diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index 4e0922694..d5779caf6 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -10,6 +10,7 @@ IVFFlatTypedDict, cli, click_parameter_decorators_from_typed_dict, + get_custom_case_config, run, ) from vectordb_bench.backend.clients import DB @@ -77,6 +78,7 @@ def PgVectorIVFFlat( ): from .config import PgVectorConfig, PgVectorIVFFlatConfig + parameters["custom_case"] = get_custom_case_config(parameters) run( db=DB.PgVector, db_config=PgVectorConfig( @@ -107,6 +109,7 @@ def PgVectorHNSW( ): from .config import PgVectorConfig, PgVectorHNSWConfig + parameters["custom_case"] = get_custom_case_config(parameters) run( db=DB.PgVector, db_config=PgVectorConfig( diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index 00910261b..950bf730b 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -17,6 +17,8 @@ Any, ) import click + +from vectordb_bench.backend.clients.api import MetricType from .. import config from ..backend.clients import DB from ..interface import benchMarkRunner, global_result_future @@ -147,6 +149,37 @@ def parse_task_stages( return stages +def check_custom_case_parameters(ctx, param, value): + if ctx.params.get("case_type") == "PerformanceCustomDataset": + if value is None: + raise click.BadParameter("Custom case parameters\ + \n--custom-case-name\n--custom-dataset-name\n--custom-dataset-dir\n--custom-dataset-size \ + \n--custom-dataset-dim\n--custom-dataset-file-count\n are required") + return value + + +def get_custom_case_config(parameters: dict) -> dict: + custom_case_config = {} + if parameters["case_type"] == "PerformanceCustomDataset": + custom_case_config = { + "name": parameters["custom_case_name"], + "description": parameters["custom_case_description"], + "load_timeout": parameters["custom_case_load_timeout"], + "optimize_timeout": parameters["custom_case_optimize_timeout"], + "dataset_config": { + "name": parameters["custom_dataset_name"], + "dir": parameters["custom_dataset_dir"], + "size": parameters["custom_dataset_size"], + "dim": parameters["custom_dataset_dim"], + "metric_type": parameters["custom_dataset_metric_type"], + "file_count": parameters["custom_dataset_file_count"], + "use_shuffled": parameters["custom_dataset_use_shuffled"], + "with_gt": parameters["custom_dataset_with_gt"], + } + } + return custom_case_config + + log = logging.getLogger(__name__) @@ -205,6 +238,7 @@ class CommonTypedDict(TypedDict): click.option( "--case-type", type=click.Choice([ct.name for ct in CaseType if ct.name != "Custom"]), + is_eager=True, default="Performance1536D50K", help="Case type", ), @@ -258,6 +292,108 @@ class CommonTypedDict(TypedDict): callback=lambda *args: list(map(int, click_arg_split(*args))), ), ] + custom_case_name: Annotated[ + str, + click.option( + "--custom-case-name", + help="Custom dataset case name", + callback=check_custom_case_parameters, + ) + ] + custom_case_description: Annotated[ + str, + click.option( + "--custom-case-description", + help="Custom dataset case description", + default="This is a customized dataset.", + show_default=True, + ) + ] + custom_case_load_timeout: Annotated[ + int, + click.option( + "--custom-case-load-timeout", + help="Custom dataset case load timeout", + default=36000, + show_default=True, + ) + ] + custom_case_optimize_timeout: Annotated[ + int, + click.option( + "--custom-case-optimize-timeout", + help="Custom dataset case optimize timeout", + default=36000, + show_default=True, + ) + ] + custom_dataset_name: Annotated[ + str, + click.option( + "--custom-dataset-name", + help="Custom dataset name", + callback=check_custom_case_parameters, + ), + ] + custom_dataset_dir: Annotated[ + str, + click.option( + "--custom-dataset-dir", + help="Custom dataset directory", + callback=check_custom_case_parameters, + ), + ] + custom_dataset_size: Annotated[ + int, + click.option( + "--custom-dataset-size", + help="Custom dataset size", + callback=check_custom_case_parameters, + ), + ] + custom_dataset_dim: Annotated[ + int, + click.option( + "--custom-dataset-dim", + help="Custom dataset dimension", + callback=check_custom_case_parameters, + ), + ] + custom_dataset_metric_type: Annotated[ + str, + click.option( + "--custom-dataset-metric-type", + help="Custom dataset metric type", + default=MetricType.COSINE.name, + show_default=True, + ), + ] + custom_dataset_file_count: Annotated[ + int, + click.option( + "--custom-dataset-file-count", + help="Custom dataset file count", + callback=check_custom_case_parameters, + ), + ] + custom_dataset_use_shuffled: Annotated[ + bool, + click.option( + "--custom-dataset-use-shuffled/--no-custom-dataset-use-shuffled", + help="Custom dataset use shuffled", + default=False, + show_default=True, + ), + ] + custom_dataset_with_gt: Annotated[ + bool, + click.option( + "--custom-dataset-with-gt/--no-custom-dataset-with-gt", + help="Custom dataset with ground truth", + default=True, + show_default=True, + ), + ] class HNSWBaseTypedDict(TypedDict): @@ -343,6 +479,7 @@ def run( concurrency_duration=parameters["concurrency_duration"], num_concurrency=[int(s) for s in parameters["num_concurrency"]], ), + custom_case=parameters["custom_case"], ), stages=parse_task_stages( (