Skip to content

Commit

Permalink
support custom_dataset
Browse files Browse the repository at this point in the history
Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 authored and XuanYang-cn committed Jul 17, 2024
1 parent 09306a0 commit 966bd80
Show file tree
Hide file tree
Showing 40 changed files with 611 additions and 207 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,24 @@ Case No. | Case Type | Dataset Size | Filtering Rate | Results |

Each case provides an in-depth examination of a vector database's abilities, providing you a comprehensive view of the database's performance.

#### Custom Dataset for Performance case

Through the `/custom` page, users can customize their own performance case using local datasets. After saving, the corresponding case can be selected from the `/run_test` page to perform the test.

![image](fig/custom_dataset.png)
![image](fig/custom_case_run_test.png)

We have strict requirements for the data set format, please follow them.
- `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
- Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
- Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
- Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`.

- `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files.

- `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.


## Goals
Our goals of this benchmark are:
### Reproducibility & Usability
Expand Down
Binary file added fig/custom_case_run_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added fig/custom_dataset.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,5 @@ zilliz_cloud = []
[project.scripts]
init_bench = "vectordb_bench.__main__:main"
vectordbbench = "vectordb_bench.cli.vectordbbench:cli"

[tool.setuptools_scm]
2 changes: 2 additions & 0 deletions vectordb_bench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class config:


K_DEFAULT = 100 # default return top k nearest neighbors during search
RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results")
CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")

CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h
Expand Down
2 changes: 1 addition & 1 deletion vectordb_bench/backend/assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Assembler:
def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
c_cls = task.case_config.case_id.case_cls

c = c_cls()
c = c_cls(task.case_config.custom_case)
if type(task.db_case_config) != EmptyDBCaseConfig:
task.db_case_config.metric_type = c.dataset.data.metric_type

Expand Down
82 changes: 64 additions & 18 deletions vectordb_bench/backend/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
from typing import Type

from vectordb_bench import config
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.base import BaseModel
from vectordb_bench.frontend.components.custom.getCustomConfig import (
CustomDatasetConfig,
)

from .dataset import Dataset, DatasetManager
from .dataset import CustomDataset, Dataset, DatasetManager


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,25 +48,24 @@ class CaseType(Enum):
Performance1536D50K = 50

Custom = 100
PerformanceCustomDataset = 101

@property
def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]:
if self not in type2case:
raise NotImplementedError(f"Case {self} has not implemented. You can add it manually to vectordb_bench.backend.cases.type2case or define a custom_configs['custom_cls']")
return type2case[self]
if custom_configs is None:
return type2case.get(self)()
else:
return type2case.get(self)(**custom_configs)

@property
def case_name(self) -> str:
c = self.case_cls
def case_name(self, custom_configs: dict | None = None) -> str:
c = self.case_cls(custom_configs)
if c is not None:
return c().name
return c.name
raise ValueError("Case unsupported")

@property
def case_description(self) -> str:
c = self.case_cls
def case_description(self, custom_configs: dict | None = None) -> str:
c = self.case_cls(custom_configs)
if c is not None:
return c().description
return c.description
raise ValueError("Case unsupported")


Expand Down Expand Up @@ -289,26 +292,69 @@ class Performance1536D50K(PerformanceCase):
optimize_timeout: float | int | None = 15 * 60


def metric_type_map(s: str) -> MetricType:
if s.lower() == "cosine":
return MetricType.COSINE
if s.lower() == "l2" or s.lower() == "euclidean":
return MetricType.L2
if s.lower() == "ip":
return MetricType.IP
err_msg = f"Not support metric_type: {s}"
log.error(err_msg)
raise RuntimeError(err_msg)


class PerformanceCustomDataset(PerformanceCase):
case_id: CaseType = CaseType.PerformanceCustomDataset
name: str = "Performance With Custom Dataset"
description: str = ""
dataset: DatasetManager

def __init__(
self,
name,
description,
load_timeout,
optimize_timeout,
dataset_config,
**kwargs,
):
dataset_config = CustomDatasetConfig(**dataset_config)
dataset = CustomDataset(
name=dataset_config.name,
size=dataset_config.size,
dim=dataset_config.dim,
metric_type=metric_type_map(dataset_config.metric_type),
use_shuffled=dataset_config.use_shuffled,
with_gt=dataset_config.with_gt,
dir=dataset_config.dir,
file_num=dataset_config.file_count,
)
super().__init__(
name=name,
description=description,
load_timeout=load_timeout,
optimize_timeout=optimize_timeout,
dataset=DatasetManager(data=dataset),
)


type2case = {
CaseType.CapacityDim960: CapacityDim960,
CaseType.CapacityDim128: CapacityDim128,

CaseType.Performance768D100M: Performance768D100M,
CaseType.Performance768D10M: Performance768D10M,
CaseType.Performance768D1M: Performance768D1M,

CaseType.Performance768D10M1P: Performance768D10M1P,
CaseType.Performance768D1M1P: Performance768D1M1P,
CaseType.Performance768D10M99P: Performance768D10M99P,
CaseType.Performance768D1M99P: Performance768D1M99P,

CaseType.Performance1536D500K: Performance1536D500K,
CaseType.Performance1536D5M: Performance1536D5M,

CaseType.Performance1536D500K1P: Performance1536D500K1P,
CaseType.Performance1536D5M1P: Performance1536D5M1P,

CaseType.Performance1536D500K99P: Performance1536D500K99P,
CaseType.Performance1536D5M99P: Performance1536D5M99P,
CaseType.Performance1536D50K: Performance1536D50K,
CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
}
32 changes: 27 additions & 5 deletions vectordb_bench/backend/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class BaseDataset(BaseModel):
use_shuffled: bool
with_gt: bool = False
_size_label: dict[int, SizeLabel] = PrivateAttr()
isCustom: bool = False

@validator("size")
def verify_size(cls, v):
Expand All @@ -52,7 +53,27 @@ def dir_name(self) -> str:
def file_count(self) -> int:
return self._size_label.get(self.size).file_count

class CustomDataset(BaseDataset):
dir: str
file_num: int
isCustom: bool = True

@validator("size")
def verify_size(cls, v):
return v

@property
def label(self) -> str:
return "Custom"

@property
def dir_name(self) -> str:
return self.dir

@property
def file_count(self) -> int:
return self.file_num

class LAION(BaseDataset):
name: str = "LAION"
dim: int = 768
Expand Down Expand Up @@ -186,11 +207,12 @@ def prepare(self,
gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
all_files.extend([gt_file, test_file])

source.reader().read(
dataset=self.data.dir_name.lower(),
files=all_files,
local_ds_root=self.data_dir,
)
if not self.data.isCustom:
source.reader().read(
dataset=self.data.dir_name.lower(),
files=all_files,
local_ds_root=self.data_dir,
)

if gt_file is not None and test_file is not None:
self.test_data = self._read_file(test_file)
Expand Down
18 changes: 18 additions & 0 deletions vectordb_bench/custom/custom_case.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[
{
"name": "My Dataset (Performace Case)",
"description": "this is a customized dataset.",
"load_timeout": 36000,
"optimize_timeout": 36000,
"dataset_config": {
"name": "My Dataset",
"dir": "/my_dataset_path",
"size": 1000000,
"dim": 1024,
"metric_type": "L2",
"file_count": 1,
"use_shuffled": false,
"with_gt": true
}
}
]
12 changes: 6 additions & 6 deletions vectordb_bench/frontend/components/check_results/charts.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from vectordb_bench.backend.cases import Case
from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
from vectordb_bench.frontend.const.styles import *
from vectordb_bench.frontend.config.styles import *
from vectordb_bench.models import ResultLabel
import plotly.express as px


def drawCharts(st, allData, failedTasks, cases: list[Case]):
def drawCharts(st, allData, failedTasks, caseNames: list[str]):
initMainExpanderStyle(st)
for case in cases:
chartContainer = st.expander(case.name, True)
data = [data for data in allData if data["case_name"] == case.name]
for caseName in caseNames:
chartContainer = st.expander(caseName, True)
data = [data for data in allData if data["case_name"] == caseName]
drawChart(data, chartContainer)

errorDBs = failedTasks[case.name]
errorDBs = failedTasks[caseName]
showFailedDBs(chartContainer, errorDBs)


Expand Down
24 changes: 12 additions & 12 deletions vectordb_bench/frontend/components/check_results/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,23 @@
def getChartData(
tasks: list[CaseResult],
dbNames: list[str],
cases: list[Case],
caseNames: list[str],
):
filterTasks = getFilterTasks(tasks, dbNames, cases)
filterTasks = getFilterTasks(tasks, dbNames, caseNames)
mergedTasks, failedTasks = mergeTasks(filterTasks)
return mergedTasks, failedTasks


def getFilterTasks(
tasks: list[CaseResult],
dbNames: list[str],
cases: list[Case],
caseNames: list[str],
) -> list[CaseResult]:
case_ids = [case.case_id for case in cases]
filterTasks = [
task
for task in tasks
if task.task_config.db_name in dbNames
and task.task_config.case_config.case_id in case_ids
and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
]
return filterTasks

Expand All @@ -36,29 +35,29 @@ def mergeTasks(tasks: list[CaseResult]):
db_name = task.task_config.db_name
db = task.task_config.db.value
db_label = task.task_config.db_config.db_label or ""
case_id = task.task_config.case_config.case_id
dbCaseMetricsMap[db_name][case_id] = {
case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
dbCaseMetricsMap[db_name][case.name] = {
"db": db,
"db_label": db_label,
"metrics": mergeMetrics(
dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
asdict(task.metrics),
),
"label": getBetterLabel(
dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
dbCaseMetricsMap[db_name][case.name].get(
"label", ResultLabel.FAILED),
task.label,
),
}

mergedTasks = []
failedTasks = defaultdict(lambda: defaultdict(str))
for db_name, caseMetricsMap in dbCaseMetricsMap.items():
for case_id, metricInfo in caseMetricsMap.items():
for case_name, metricInfo in caseMetricsMap.items():
metrics = metricInfo["metrics"]
db = metricInfo["db"]
db_label = metricInfo["db_label"]
label = metricInfo["label"]
case_name = case_id.case_name
if label == ResultLabel.NORMAL:
mergedTasks.append(
{
Expand All @@ -80,7 +79,8 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
metrics = {**metrics_1}
for key, value in metrics_2.items():
metrics[key] = (
getBetterMetric(key, value, metrics[key]) if key in metrics else value
getBetterMetric(
key, value, metrics[key]) if key in metrics else value
)

return metrics
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def initMainExpanderStyle(st):
st.markdown(
"""<style>
.main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}
.main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
.main div[data-testid='stExpander'] {
background-color: #F6F8FA;
border: 1px solid #A9BDD140;
Expand Down
Loading

0 comments on commit 966bd80

Please sign in to comment.