diff --git a/README.md b/README.md index b779af11f..a086e5d1f 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ Options: --m INTEGER hnsw m --ef-construction INTEGER hnsw ef-construction --ef-search INTEGER hnsw ef-search - --quantization-type [none|halfvec] + --quantization-type [none|bit|halfvec] quantization type for vectors --custom-case-name TEXT Custom case name i.e. PerformanceCase1536D50K --custom-case-description TEXT Custom name description diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index faa36712d..0c26fdd3b 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -10,6 +10,8 @@ class MetricType(str, Enum): L2 = "L2" COSINE = "COSINE" IP = "IP" + HAMMING = "HAMMING" + JACCARD = "JACCARD" class IndexType(str, Enum): diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index d5779caf6..cde125cc1 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -4,6 +4,8 @@ import os from pydantic import SecretStr +from vectordb_bench.backend.clients.api import MetricType + from ....cli.cli import ( CommonTypedDict, HNSWFlavor1, @@ -16,6 +18,13 @@ from vectordb_bench.backend.clients import DB + +def set_default_quantized_fetch_limit(ctx, param, value): + if ctx.params.get("reranking") and value is None: + # ef_search is the default value for quantized_fetch_limit as it's bound by ef_search. + return ctx.params["ef_search"] + return value + class PgVectorTypedDict(CommonTypedDict): user_name: Annotated[ str, click.option("--user-name", type=str, help="Db username", required=True) @@ -61,11 +70,45 @@ class PgVectorTypedDict(CommonTypedDict): Optional[str], click.option( "--quantization-type", - type=click.Choice(["none", "halfvec"]), + type=click.Choice(["none", "bit", "halfvec"]), help="quantization type for vectors", required=False, ), ] + reranking: Annotated[ + Optional[bool], + click.option( + "--reranking/--skip-reranking", + type=bool, + help="Enable reranking for HNSW search for binary quantization", + default=False, + ), + ] + reranking_metric: Annotated[ + Optional[str], + click.option( + "--reranking-metric", + type=click.Choice( + [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]] + ), + help="Distance metric for reranking", + default="COSINE", + show_default=True, + ), + ] + quantized_fetch_limit: Annotated[ + Optional[int], + click.option( + "--quantized-fetch-limit", + type=int, + help="Limit of fetching quantized vector ranked by distance for reranking \ + -- bound by ef_search", + required=False, + callback=set_default_quantized_fetch_limit, + ) + ] + + class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict): ... @@ -126,6 +169,9 @@ def PgVectorHNSW( maintenance_work_mem=parameters["maintenance_work_mem"], max_parallel_workers=parameters["max_parallel_workers"], quantization_type=parameters["quantization_type"], + reranking=parameters["reranking"], + reranking_metric=parameters["reranking_metric"], + quantized_fetch_limit=parameters["quantized_fetch_limit"], ), **parameters, ) diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index 857211234..31d832f13 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -65,6 +65,10 @@ def parse_metric(self) -> str: elif self.metric_type == MetricType.IP: return "halfvec_ip_ops" return "halfvec_cosine_ops" + elif self.quantization_type == "bit": + if self.metric_type == MetricType.JACCARD: + return "bit_jaccard_ops" + return "bit_hamming_ops" else: if self.metric_type == MetricType.L2: return "vector_l2_ops" @@ -73,11 +77,16 @@ def parse_metric(self) -> str: return "vector_cosine_ops" def parse_metric_fun_op(self) -> LiteralString: - if self.metric_type == MetricType.L2: - return "<->" - elif self.metric_type == MetricType.IP: - return "<#>" - return "<=>" + if self.quantization_type == "bit": + if self.metric_type == MetricType.JACCARD: + return "<%>" + return "<~>" + else: + if self.metric_type == MetricType.L2: + return "<->" + elif self.metric_type == MetricType.IP: + return "<#>" + return "<=>" def parse_metric_fun_str(self) -> str: if self.metric_type == MetricType.L2: @@ -85,6 +94,14 @@ def parse_metric_fun_str(self) -> str: elif self.metric_type == MetricType.IP: return "max_inner_product" return "cosine_distance" + + def parse_reranking_metric_fun_op(self) -> LiteralString: + if self.reranking_metric == MetricType.L2: + return "<->" + elif self.reranking_metric == MetricType.IP: + return "<#>" + return "<=>" + @abstractmethod def index_param(self) -> PgVectorIndexParam: @@ -195,6 +212,9 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): maintenance_work_mem: Optional[str] = None max_parallel_workers: Optional[int] = None quantization_type: Optional[str] = None + reranking: Optional[bool] = None + quantized_fetch_limit: Optional[int] = None + reranking_metric: Optional[str] = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"m": self.m, "ef_construction": self.ef_construction} @@ -214,6 +234,9 @@ def index_param(self) -> PgVectorIndexParam: def search_param(self) -> PgVectorSearchParam: return { "metric_fun_op": self.parse_metric_fun_op(), + "reranking": self.reranking, + "reranking_metric_fun_op": self.parse_reranking_metric_fun_op(), + "quantized_fetch_limit": self.quantized_fetch_limit, } def session_param(self) -> PgVectorSessionCommands: diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 8123acf18..069b89381 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -11,7 +11,7 @@ from psycopg import Connection, Cursor, sql from ..api import VectorDB -from .config import PgVectorConfigDict, PgVectorIndexConfig +from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig log = logging.getLogger(__name__) @@ -87,6 +87,92 @@ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: assert cursor is not None, "Cursor is not initialized" return conn, cursor + + def _generate_search_query(self, filtered: bool=False) -> sql.Composed: + index_param = self.case_config.index_param() + reranking = self.case_config.search_param()["reranking"] + column_name = ( + sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding")) + if index_param["quantization_type"] == "bit" + else sql.SQL("embedding") + ) + search_vector = ( + sql.SQL("binary_quantize({0})").format(sql.Placeholder()) + if index_param["quantization_type"] == "bit" + else sql.Placeholder() + ) + + # The following sections assume that the quantization_type value matches the quantization function name + if index_param["quantization_type"] != None: + if index_param["quantization_type"] == "bit" and reranking: + # Embeddings needs to be passed to binary_quantize function if quantization_type is bit + search_query = sql.Composed( + [ + sql.SQL( + """ + SELECT i.id + FROM ( + SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance + FROM public.{table_name} {where_clause} + ORDER BY {column_name}::{quantization_type}({dim}) + """ + ).format( + table_name=sql.Identifier(self.table_name), + column_name=column_name, + reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]), + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), + ), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL( + """ + {search_vector} + LIMIT {quantized_fetch_limit} + ) i + ORDER BY i.distance + LIMIT %s::int + """ + ).format( + search_vector=search_vector, + quantized_fetch_limit=sql.Literal( + self.case_config.search_param()["quantized_fetch_limit"] + ), + ), + ] + ) + else: + search_query = sql.Composed( + [ + sql.SQL( + "SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) " + ).format( + table_name=sql.Identifier(self.table_name), + column_name=column_name, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), + ), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector), + ] + ) + else: + search_query = sql.Composed( + [ + sql.SQL( + "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding " + ).format( + table_name=sql.Identifier(self.table_name), + where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), + ), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::vector LIMIT %s::int"), + ] + ) + + return search_query + @contextmanager def init(self) -> Generator[None, None, None]: @@ -112,63 +198,8 @@ def init(self) -> Generator[None, None, None]: self.cursor.execute(command) self.conn.commit() - index_param = self.case_config.index_param() - # The following sections assume that the quantization_type value matches the quantization function name - if index_param["quantization_type"] != None: - self._filtered_search = sql.Composed( - [ - sql.SQL( - "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) " - ).format( - table_name=sql.Identifier(self.table_name), - quantization_type=sql.SQL(index_param["quantization_type"]), - dim=sql.Literal(self.dim), - ), - sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format( - quantization_type=sql.SQL(index_param["quantization_type"]), - dim=sql.Literal(self.dim), - ), - ] - ) - else: - self._filtered_search = sql.Composed( - [ - sql.SQL( - "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " - ).format(table_name=sql.Identifier(self.table_name)), - sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), - ] - ) - - if index_param["quantization_type"] != None: - self._unfiltered_search = sql.Composed( - [ - sql.SQL( - "SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) " - ).format( - table_name=sql.Identifier(self.table_name), - quantization_type=sql.SQL(index_param["quantization_type"]), - dim=sql.Literal(self.dim), - ), - sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format( - quantization_type=sql.SQL(index_param["quantization_type"]), - dim=sql.Literal(self.dim), - ), - ] - ) - else: - self._unfiltered_search = sql.Composed( - [ - sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format( - sql.Identifier(self.table_name) - ), - sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), - ] - ) + self._filtered_search = self._generate_search_query(filtered=True) + self._unfiltered_search = self._generate_search_query() try: yield @@ -306,12 +337,17 @@ def _create_index(self): if index_param["quantization_type"] != None: index_create_sql = sql.SQL( """ - CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} - USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric}) + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric}) """ ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), + column_name=( + sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding")) + if index_param["quantization_type"] == "bit" + else sql.Identifier("embedding") + ), index_type=sql.Identifier(index_param["index_type"]), # This assumes that the quantization_type value matches the quantization function name quantization_type=sql.SQL(index_param["quantization_type"]), @@ -406,15 +442,28 @@ def search_embedding( assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + index_param = self.case_config.index_param() + search_param = self.case_config.search_param() q = np.asarray(query) if filters: gt = filters.get("id") - result = self.cursor.execute( + if index_param["quantization_type"] == "bit" and search_param["reranking"]: + result = self.cursor.execute( + self._filtered_search, (q, gt, q, k), prepare=True, binary=True + ) + else: + result = self.cursor.execute( self._filtered_search, (gt, q, k), prepare=True, binary=True - ) + ) + else: - result = self.cursor.execute( + if index_param["quantization_type"] == "bit" and search_param["reranking"]: + result = self.cursor.execute( + self._unfiltered_search, (q, q, k), prepare=True, binary=True + ) + else: + result = self.cursor.execute( self._unfiltered_search, (q, k), prepare=True, binary=True - ) + ) return [int(i[0]) for i in result.fetchall()] diff --git a/vectordb_bench/frontend/components/run_test/caseSelector.py b/vectordb_bench/frontend/components/run_test/caseSelector.py index 5597bbc61..b25618271 100644 --- a/vectordb_bench/frontend/components/run_test/caseSelector.py +++ b/vectordb_bench/frontend/components/run_test/caseSelector.py @@ -110,6 +110,12 @@ def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, active value=config.inputConfig["value"], help=config.inputHelp, ) + elif config.inputType == InputType.Bool: + caseConfig[config.label] = column.checkbox( + config.displayLabel if config.displayLabel else config.label.value, + value=config.inputConfig["value"], + help=config.inputHelp, + ) k += 1 if k == 0: columns[1].write("Auto") diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 68bf83f19..ce64de2b8 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from vectordb_bench.backend.cases import CaseLabel, CaseType from vectordb_bench.backend.clients import DB -from vectordb_bench.backend.clients.api import IndexType +from vectordb_bench.backend.clients.api import IndexType, MetricType from vectordb_bench.frontend.components.custom.getCustomConfig import get_custom_configs from vectordb_bench.models import CaseConfig, CaseConfigParamType @@ -149,6 +149,7 @@ class InputType(IntEnum): Number = 20002 Option = 20003 Float = 20004 + Bool = 20005 class CaseConfigInput(BaseModel): @@ -773,7 +774,7 @@ class CaseConfigInput(BaseModel): label=CaseConfigParamType.quantizationType, inputType=InputType.Option, inputConfig={ - "options": ["none", "halfvec"], + "options": ["none", "bit", "halfvec"], }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [ @@ -819,6 +820,46 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_reranking_PgVector = CaseConfigInput( + label=CaseConfigParamType.reranking, + inputType=InputType.Bool, + displayLabel="Enable Reranking", + inputHelp="Enable if you want to use reranking while performing \ + similarity search in binary quantization", + inputConfig={ + "value": False, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) + == "bit" +) + +CaseConfigParamInput_quantized_fetch_limit_PgVector = CaseConfigInput( + label=CaseConfigParamType.quantizedFetchLimit, + displayLabel="Quantized vector fetch limit", + inputHelp="Limit top-k vectors using the quantized vector comparison --bound by ef_search", + inputType=InputType.Number, + inputConfig={ + "min": 20, + "max": 1000, + "value": 200, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) + == "bit" and config.get(CaseConfigParamType.reranking, False) +) + + +CaseConfigParamInput_reranking_metric_PgVector = CaseConfigInput( + label=CaseConfigParamType.rerankingMetric, + inputType=InputType.Option, + inputConfig={ + "options": [ + metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"] + ], + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) + == "bit" and config.get(CaseConfigParamType.reranking, False) +) + MilvusLoadConfig = [ CaseConfigParamInput_IndexType, CaseConfigParamInput_M, @@ -896,6 +937,9 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_QuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, + CaseConfigParamInput_reranking_PgVector, + CaseConfigParamInput_reranking_metric_PgVector, + CaseConfigParamInput_quantized_fetch_limit_PgVector, ] PgVectoRSLoadingConfig = [ diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 7968e3e26..c74881d77 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -47,6 +47,9 @@ class CaseConfigParamType(Enum): probes = "probes" quantizationType = "quantization_type" quantizationRatio = "quantization_ratio" + reranking = "reranking" + rerankingMetric = "reranking_metric" + quantizedFetchLimit = "quantized_fetch_limit" m = "m" nbits = "nbits" intermediate_graph_degree = "intermediate_graph_degree"