Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cli support for running benchmark with custom dataset #372

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vectordb_bench/backend/clients/pgvector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
IVFFlatTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
get_custom_case_config,
run,
)
from vectordb_bench.backend.clients import DB
Expand Down Expand Up @@ -77,6 +78,7 @@ def PgVectorIVFFlat(
):
from .config import PgVectorConfig, PgVectorIVFFlatConfig

parameters["custom_case"] = get_custom_case_config(parameters)
run(
db=DB.PgVector,
db_config=PgVectorConfig(
Expand Down Expand Up @@ -107,6 +109,7 @@ def PgVectorHNSW(
):
from .config import PgVectorConfig, PgVectorHNSWConfig

parameters["custom_case"] = get_custom_case_config(parameters)
run(
db=DB.PgVector,
db_config=PgVectorConfig(
Expand Down
137 changes: 137 additions & 0 deletions vectordb_bench/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Any,
)
import click

from vectordb_bench.backend.clients.api import MetricType
from .. import config
from ..backend.clients import DB
from ..interface import benchMarkRunner, global_result_future
Expand Down Expand Up @@ -147,6 +149,37 @@ def parse_task_stages(
return stages


def check_custom_case_parameters(ctx, param, value):
if ctx.params.get("case_type") == "PerformanceCustomDataset":
if value is None:
raise click.BadParameter("Custom case parameters\
\n--custom-case-name\n--custom-dataset-name\n--custom-dataset-dir\n--custom-dataset-size \
\n--custom-dataset-dim\n--custom-dataset-file-count\n are required")
return value


def get_custom_case_config(parameters: dict) -> dict:
custom_case_config = {}
if parameters["case_type"] == "PerformanceCustomDataset":
custom_case_config = {
"name": parameters["custom_case_name"],
"description": parameters["custom_case_description"],
"load_timeout": parameters["custom_case_load_timeout"],
"optimize_timeout": parameters["custom_case_optimize_timeout"],
"dataset_config": {
"name": parameters["custom_dataset_name"],
"dir": parameters["custom_dataset_dir"],
"size": parameters["custom_dataset_size"],
"dim": parameters["custom_dataset_dim"],
"metric_type": parameters["custom_dataset_metric_type"],
"file_count": parameters["custom_dataset_file_count"],
"use_shuffled": parameters["custom_dataset_use_shuffled"],
"with_gt": parameters["custom_dataset_with_gt"],
}
}
return custom_case_config


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -205,6 +238,7 @@ class CommonTypedDict(TypedDict):
click.option(
"--case-type",
type=click.Choice([ct.name for ct in CaseType if ct.name != "Custom"]),
is_eager=True,
default="Performance1536D50K",
help="Case type",
),
Expand Down Expand Up @@ -258,6 +292,108 @@ class CommonTypedDict(TypedDict):
callback=lambda *args: list(map(int, click_arg_split(*args))),
),
]
custom_case_name: Annotated[
str,
click.option(
"--custom-case-name",
help="Custom dataset case name",
callback=check_custom_case_parameters,
)
]
custom_case_description: Annotated[
str,
click.option(
"--custom-case-description",
help="Custom dataset case description",
default="This is a customized dataset.",
show_default=True,
)
]
custom_case_load_timeout: Annotated[
int,
click.option(
"--custom-case-load-timeout",
help="Custom dataset case load timeout",
default=36000,
show_default=True,
)
]
custom_case_optimize_timeout: Annotated[
int,
click.option(
"--custom-case-optimize-timeout",
help="Custom dataset case optimize timeout",
default=36000,
show_default=True,
)
]
custom_dataset_name: Annotated[
str,
click.option(
"--custom-dataset-name",
help="Custom dataset name",
callback=check_custom_case_parameters,
),
]
custom_dataset_dir: Annotated[
str,
click.option(
"--custom-dataset-dir",
help="Custom dataset directory",
callback=check_custom_case_parameters,
),
]
custom_dataset_size: Annotated[
int,
click.option(
"--custom-dataset-size",
help="Custom dataset size",
callback=check_custom_case_parameters,
),
]
custom_dataset_dim: Annotated[
int,
click.option(
"--custom-dataset-dim",
help="Custom dataset dimension",
callback=check_custom_case_parameters,
),
]
custom_dataset_metric_type: Annotated[
str,
click.option(
"--custom-dataset-metric-type",
help="Custom dataset metric type",
default=MetricType.COSINE.name,
show_default=True,
),
]
custom_dataset_file_count: Annotated[
int,
click.option(
"--custom-dataset-file-count",
help="Custom dataset file count",
callback=check_custom_case_parameters,
),
]
custom_dataset_use_shuffled: Annotated[
bool,
click.option(
"--custom-dataset-use-shuffled/--no-custom-dataset-use-shuffled",
help="Custom dataset use shuffled",
default=False,
show_default=True,
),
]
custom_dataset_with_gt: Annotated[
bool,
click.option(
"--custom-dataset-with-gt/--no-custom-dataset-with-gt",
help="Custom dataset with ground truth",
default=True,
show_default=True,
),
]


class HNSWBaseTypedDict(TypedDict):
Expand Down Expand Up @@ -343,6 +479,7 @@ def run(
concurrency_duration=parameters["concurrency_duration"],
num_concurrency=[int(s) for s in parameters["num_concurrency"]],
),
custom_case=parameters["custom_case"],
),
stages=parse_task_stages(
(
Expand Down