Skip to content

Commit

Permalink
add new db_config for better labeling: version, note
Browse files Browse the repository at this point in the history
Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 committed Jul 24, 2024
1 parent c45876c commit 5ef1daf
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 27 deletions.
21 changes: 20 additions & 1 deletion vectordb_bench/backend/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,33 @@ class DBConfig(ABC, BaseModel):
"""

db_label: str = ""
version: str = ""
note: str = ""

@staticmethod
def common_short_configs() -> list[str]:
    """Return the names of the shared short text fields.

    These are the one-line labeling inputs common to every database
    config: ``version`` and ``db_label``.
    """
    short_fields = ["version", "db_label"]
    return short_fields

@staticmethod
def common_long_configs() -> list[str]:
    """Return the names of the shared long text fields.

    Currently only ``note``, the free-form multi-line input common to
    every database config.
    """
    long_fields = ["note"]
    return long_fields

@abstractmethod
def to_dict(self) -> dict:
    """Return this configuration as a plain ``dict``.

    Abstract: concrete configs must override this. The base
    implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError()

@validator("*")
def not_empty_field(cls, v, field):
if field.name == "db_label":
if (
field.name in cls.common_short_configs()
or field.name in cls.common_long_configs()
):
return v
if not v and isinstance(v, (str, SecretStr)):
raise ValueError("Empty string!")
Expand Down
19 changes: 13 additions & 6 deletions vectordb_bench/frontend/components/check_results/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def getFilterTasks(
task
for task in tasks
if task.task_config.db_name in dbNames
and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
and task.task_config.case_config.case_id.case_cls(
task.task_config.case_config.custom_case
).name
in caseNames
]
return filterTasks

Expand All @@ -35,17 +38,20 @@ def mergeTasks(tasks: list[CaseResult]):
db_name = task.task_config.db_name
db = task.task_config.db.value
db_label = task.task_config.db_config.db_label or ""
case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
version = task.task_config.db_config.version or ""
case = task.task_config.case_config.case_id.case_cls(
task.task_config.case_config.custom_case
)
dbCaseMetricsMap[db_name][case.name] = {
"db": db,
"db_label": db_label,
"version": version,
"metrics": mergeMetrics(
dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
asdict(task.metrics),
),
"label": getBetterLabel(
dbCaseMetricsMap[db_name][case.name].get(
"label", ResultLabel.FAILED),
dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED),
task.label,
),
}
Expand All @@ -57,13 +63,15 @@ def mergeTasks(tasks: list[CaseResult]):
metrics = metricInfo["metrics"]
db = metricInfo["db"]
db_label = metricInfo["db_label"]
version = metricInfo["version"]
label = metricInfo["label"]
if label == ResultLabel.NORMAL:
mergedTasks.append(
{
"db_name": db_name,
"db": db,
"db_label": db_label,
"version": version,
"case_name": case_name,
"metricsSet": set(metrics.keys()),
**metrics,
Expand All @@ -79,8 +87,7 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
metrics = {**metrics_1}
for key, value in metrics_2.items():
metrics[key] = (
getBetterMetric(
key, value, metrics[key]) if key in metrics else value
getBetterMetric(key, value, metrics[key]) if key in metrics else value
)

return metrics
Expand Down
52 changes: 37 additions & 15 deletions vectordb_bench/frontend/components/run_test/dbConfigSetting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pydantic import ValidationError
from vectordb_bench.frontend.config.styles import *
from vectordb_bench.backend.clients import DB
from vectordb_bench.frontend.config.styles import DB_CONFIG_SETTING_COLUMNS
from vectordb_bench.frontend.utils import inputIsPassword


def dbConfigSettings(st, activedDbList):
def dbConfigSettings(st, activedDbList: list[DB]):
expander = st.expander("Configurations for the selected databases", True)

dbConfigs = {}
Expand All @@ -27,7 +28,7 @@ def dbConfigSettings(st, activedDbList):
return dbConfigs, isAllValid


def dbConfigSettingItem(st, activeDb):
def dbConfigSettingItem(st, activeDb: DB):
st.markdown(
f"<div style='font-weight: 600; font-size: 20px; margin-top: 16px;'>{activeDb.value}</div>",
unsafe_allow_html=True,
Expand All @@ -36,20 +37,41 @@ def dbConfigSettingItem(st, activeDb):

dbConfigClass = activeDb.config_cls
properties = dbConfigClass.schema().get("properties")
propertiesItems = list(properties.items())
moveDBLabelToLast(propertiesItems)
dbConfig = {}
for j, property in enumerate(propertiesItems):
column = columns[j % DB_CONFIG_SETTING_COLUMNS]
key, value = property
idx = 0

# db config (unique)
for key, property in properties.items():
if (
key not in dbConfigClass.common_short_configs()
and key not in dbConfigClass.common_long_configs()
):
column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
idx += 1
dbConfig[key] = column.text_input(
key,
key="%s-%s" % (activeDb.name, key),
value=property.get("default", ""),
type="password" if inputIsPassword(key) else "default",
)
# db config (common short labels)
for key in dbConfigClass.common_short_configs():
column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
idx += 1
dbConfig[key] = column.text_input(
key,
key="%s-%s" % (activeDb, key),
value=value.get("default", ""),
type="password" if inputIsPassword(key) else "default",
key="%s-%s" % (activeDb.name, key),
value="",
type="default",
placeholder="optional, for labeling results",
)
return dbConfig


def moveDBLabelToLast(propertiesItems):
    """Reorder ``propertiesItems`` in place so ``db_label`` comes last.

    ``propertiesItems`` is a list of ``(key, value)`` pairs. The sort is
    stable, so all other entries keep their relative order. Returns
    ``None``; the list is mutated.
    """

    def sortRank(item):
        # 1 sends the db_label pair to the end; everything else ranks 0.
        return 1 if item[0] == "db_label" else 0

    propertiesItems.sort(key=sortRank)
# db config (common long text_input)
for key in dbConfigClass.common_long_configs():
dbConfig[key] = st.text_area(
key,
key="%s-%s" % (activeDb.name, key),
value="",
placeholder="optional",
)
return dbConfig
4 changes: 3 additions & 1 deletion vectordb_bench/frontend/components/run_test/initStyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def initStyle(st):
div[data-testid='stHorizontalBlock'] {gap: 8px;}
/* check box */
.stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
/* db selector - db_name should not wrap */
div[data-testid="stVerticalBlockBorderWrapper"] div[data-testid="stCheckbox"] div[data-testid="stWidgetLabel"] p { white-space: nowrap; }
</style>""",
unsafe_allow_html=True,
)
)
12 changes: 8 additions & 4 deletions vectordb_bench/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import pathlib
from datetime import date
from enum import Enum, StrEnum, auto
from typing import List, Self, Sequence, Set
from typing import List, Self

import ujson

from .backend.clients import (
DB,
DBConfig,
DBCaseConfig,
IndexType,
)
from .backend.cases import CaseType
from .base import BaseModel
Expand Down Expand Up @@ -128,9 +127,14 @@ class TaskConfig(BaseModel):

@property
def db_name(self):
    """Build the display name used to group results for this task.

    The name always starts with the database's enum value. A non-empty
    ``db_config.db_label`` and then a non-empty ``db_config.version``
    are appended, each separated by ``-``:

        "<db>", "<db>-<label>", "<db>-<version>" or "<db>-<label>-<version>"

    Returns:
        str: the composed name.
    """
    # The base name is always present, even if a suffix is missing.
    parts = [f"{self.db.value}"]
    # Order matters: label first, then version. Empty values are skipped
    # so no dangling "-" is produced.
    optional_suffixes = (self.db_config.db_label, self.db_config.version)
    parts.extend(suffix for suffix in optional_suffixes if suffix)
    return "-".join(parts)


class ResultLabel(Enum):
Expand Down

0 comments on commit 5ef1daf

Please sign in to comment.