Skip to content

Commit

Permalink
Add cli support for running benchmark with custom dataset in pgvector…
Browse files Browse the repository at this point in the history
… hnsw and ivfflat.
  • Loading branch information
Sheharyar570 committed Oct 2, 2024
1 parent b364fe3 commit d8bd961
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 0 deletions.
3 changes: 3 additions & 0 deletions vectordb_bench/backend/clients/pgvector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
IVFFlatTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
get_custom_case_config,
run,
)
from vectordb_bench.backend.clients import DB
Expand Down Expand Up @@ -77,6 +78,7 @@ def PgVectorIVFFlat(
):
from .config import PgVectorConfig, PgVectorIVFFlatConfig

parameters["custom_case"] = get_custom_case_config(parameters)
run(
db=DB.PgVector,
db_config=PgVectorConfig(
Expand Down Expand Up @@ -107,6 +109,7 @@ def PgVectorHNSW(
):
from .config import PgVectorConfig, PgVectorHNSWConfig

parameters["custom_case"] = get_custom_case_config(parameters)
run(
db=DB.PgVector,
db_config=PgVectorConfig(
Expand Down
137 changes: 137 additions & 0 deletions vectordb_bench/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Any,
)
import click

from vectordb_bench.backend.clients.api import MetricType
from .. import config
from ..backend.clients import DB
from ..interface import benchMarkRunner, global_result_future
Expand Down Expand Up @@ -147,6 +149,37 @@ def parse_task_stages(
return stages


def check_custom_case_parameters(ctx, param, value):
    """Click option callback enforcing the options a custom-dataset case needs.

    Args:
        ctx: Click context; ``ctx.params`` is consulted for ``"case_type"``.
        param: The option being validated (not used by the check itself).
        value: The value supplied for that option, or ``None`` when omitted.

    Returns:
        ``value`` unchanged when the check passes.

    Raises:
        click.BadParameter: If ``case_type`` is ``"PerformanceCustomDataset"``
            and ``value`` is ``None``.
    """
    # The check only applies when "case_type" is already present in ctx.params
    # with the custom-dataset value; .get() keeps every other case a no-op.
    if ctx.params.get("case_type") == "PerformanceCustomDataset" and value is None:
        # Adjacent string literals instead of backslash continuations inside a
        # single literal: a continuation pulls the next source line's leading
        # whitespace (and any trailing space) into the user-facing message.
        raise click.BadParameter(
            "Custom case parameters"
            "\n--custom-case-name"
            "\n--custom-dataset-name"
            "\n--custom-dataset-dir"
            "\n--custom-dataset-size"
            "\n--custom-dataset-dim"
            "\n--custom-dataset-file-count"
            "\nare required"
        )
    return value


def get_custom_case_config(parameters: dict) -> dict:
    """Assemble the custom-case settings from the parsed CLI parameters.

    Args:
        parameters: Mapping of CLI parameter names to values. ``"case_type"``
            is always read; the ``custom_case_*`` and ``custom_dataset_*``
            entries are read only for the custom-dataset case.

    Returns:
        An empty dict unless ``case_type`` is ``"PerformanceCustomDataset"``;
        otherwise a dict with the case fields plus a nested
        ``"dataset_config"`` dict.
    """
    if parameters["case_type"] != "PerformanceCustomDataset":
        return {}

    # Output keys, listed in the order they appear in the result; each maps to
    # the CLI parameter with the corresponding prefix.
    case_fields = ("name", "description", "load_timeout", "optimize_timeout")
    dataset_fields = (
        "name",
        "dir",
        "size",
        "dim",
        "metric_type",
        "file_count",
        "use_shuffled",
        "with_gt",
    )

    config = {field: parameters[f"custom_case_{field}"] for field in case_fields}
    config["dataset_config"] = {
        field: parameters[f"custom_dataset_{field}"] for field in dataset_fields
    }
    return config


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -205,6 +238,7 @@ class CommonTypedDict(TypedDict):
click.option(
"--case-type",
type=click.Choice([ct.name for ct in CaseType if ct.name != "Custom"]),
is_eager=True,
default="Performance1536D50K",
help="Case type",
),
Expand Down Expand Up @@ -258,6 +292,108 @@ class CommonTypedDict(TypedDict):
callback=lambda *args: list(map(int, click_arg_split(*args))),
),
]
custom_case_name: Annotated[
str,
click.option(
"--custom-case-name",
help="Custom dataset case name",
callback=check_custom_case_parameters,
)
]
custom_case_description: Annotated[
str,
click.option(
"--custom-case-description",
help="Custom dataset case description",
default="This is a customized dataset.",
show_default=True,
)
]
custom_case_load_timeout: Annotated[
int,
click.option(
"--custom-case-load-timeout",
help="Custom dataset case load timeout",
default=36000,
show_default=True,
)
]
custom_case_optimize_timeout: Annotated[
int,
click.option(
"--custom-case-optimize-timeout",
help="Custom dataset case optimize timeout",
default=36000,
show_default=True,
)
]
custom_dataset_name: Annotated[
str,
click.option(
"--custom-dataset-name",
help="Custom dataset name",
callback=check_custom_case_parameters,
),
]
custom_dataset_dir: Annotated[
str,
click.option(
"--custom-dataset-dir",
help="Custom dataset directory",
callback=check_custom_case_parameters,
),
]
custom_dataset_size: Annotated[
int,
click.option(
"--custom-dataset-size",
help="Custom dataset size",
callback=check_custom_case_parameters,
),
]
custom_dataset_dim: Annotated[
int,
click.option(
"--custom-dataset-dim",
help="Custom dataset dimension",
callback=check_custom_case_parameters,
),
]
custom_dataset_metric_type: Annotated[
str,
click.option(
"--custom-dataset-metric-type",
help="Custom dataset metric type",
default=MetricType.COSINE.name,
show_default=True,
),
]
custom_dataset_file_count: Annotated[
int,
click.option(
"--custom-dataset-file-count",
help="Custom dataset file count",
callback=check_custom_case_parameters,
),
]
custom_dataset_use_shuffled: Annotated[
bool,
click.option(
"--custom-dataset-use-shuffled/--no-custom-dataset-use-shuffled",
help="Custom dataset use shuffled",
default=False,
show_default=True,
),
]
custom_dataset_with_gt: Annotated[
bool,
click.option(
"--custom-dataset-with-gt/--no-custom-dataset-with-gt",
help="Custom dataset with ground truth",
default=True,
show_default=True,
),
]


class HNSWBaseTypedDict(TypedDict):
Expand Down Expand Up @@ -343,6 +479,7 @@ def run(
concurrency_duration=parameters["concurrency_duration"],
num_concurrency=[int(s) for s in parameters["num_concurrency"]],
),
custom_case=parameters["custom_case"],
),
stages=parse_task_stages(
(
Expand Down

0 comments on commit d8bd961

Please sign in to comment.