diff --git a/README.md b/README.md index 893c0497..d00a1304 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,11 @@ Options: --ef-construction INTEGER hnsw ef-construction --ef-search INTEGER hnsw ef-search --quantization-type [none|bit|halfvec] - quantization type for vectors + quantization type for vectors (in index) + --table-quantization-type [none|bit|halfvec] + quantization type for vectors (in table). If + equal to bit, the parameter + quantization_type will be set to bit too. --custom-case-name TEXT Custom case name i.e. PerformanceCase1536D50K --custom-case-description TEXT Custom name description --custom-case-load-timeout INTEGER diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index 43385ee6..5d90941e 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -72,7 +72,17 @@ class PgVectorTypedDict(CommonTypedDict): click.option( "--quantization-type", type=click.Choice(["none", "bit", "halfvec"]), - help="quantization type for vectors", + help="quantization type for vectors (in index)", + required=False, + ), + ] + table_quantization_type: Annotated[ + Optional[str], + click.option( + "--table-quantization-type", + type=click.Choice(["none", "bit", "halfvec"]), + help="quantization type for vectors (in table). " + "If equal to bit, the parameter quantization_type will be set to bit too.", required=False, ), ] @@ -137,6 +147,7 @@ def PgVectorIVFFlat( lists=parameters["lists"], probes=parameters["probes"], quantization_type=parameters["quantization_type"], + table_quantization_type=parameters["table_quantization_type"], reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], @@ -173,6 +184,7 @@ def PgVectorHNSW( maintenance_work_mem=parameters["maintenance_work_mem"], max_parallel_workers=parameters["max_parallel_workers"], quantization_type=parameters["quantization_type"], + table_quantization_type=parameters["table_quantization_type"], reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index 16d54744..3443ab03 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -168,14 +168,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig): maintenance_work_mem: Optional[str] = None max_parallel_workers: Optional[int] = None quantization_type: Optional[str] = None + table_quantization_type: Optional[str] = None reranking: Optional[bool] = None quantized_fetch_limit: Optional[int] = None reranking_metric: Optional[str] = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"lists": self.lists} - if self.quantization_type == "none": - self.quantization_type = None + if self.quantization_type == "none" or self.quantization_type == None: + self.quantization_type = "vector" + if self.table_quantization_type == "none" or self.table_quantization_type == None: + self.table_quantization_type = "vector" + if self.table_quantization_type == "bit": + self.quantization_type = "bit" return { "metric": self.parse_metric(), "index_type": self.index.value, @@ -185,6 +190,7 @@ def index_param(self) -> PgVectorIndexParam: "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "quantization_type": self.quantization_type, + "table_quantization_type": self.table_quantization_type, } def search_param(self) -> PgVectorSearchParam: @@ -218,14 +224,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): maintenance_work_mem: Optional[str] = None max_parallel_workers: Optional[int] = None quantization_type: Optional[str] = None + table_quantization_type: Optional[str] = None reranking: Optional[bool] = None quantized_fetch_limit: Optional[int] = None reranking_metric: Optional[str] = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"m": self.m, "ef_construction": self.ef_construction} - if self.quantization_type == "none": - self.quantization_type = None + if self.quantization_type == "none" or self.quantization_type == None: + self.quantization_type = "vector" + if self.table_quantization_type == "none" or self.table_quantization_type == None: + self.table_quantization_type = "vector" + if self.table_quantization_type == "bit": + self.quantization_type = "bit" return { "metric": self.parse_metric(), "index_type": self.index.value, @@ -235,6 +246,7 @@ def index_param(self) -> PgVectorIndexParam: "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "quantization_type": self.quantization_type, + "table_quantization_type": self.table_quantization_type, } def search_param(self) -> PgVectorSearchParam: diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 069b8938..e211cd9f 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -93,7 +93,7 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: reranking = self.case_config.search_param()["reranking"] column_name = ( sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding")) - if index_param["quantization_type"] == "bit" + if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit" else sql.SQL("embedding") ) search_vector = ( @@ -103,7 +103,8 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: ) # The following sections assume that the quantization_type value matches the quantization function name - if index_param["quantization_type"] != None: + if index_param["quantization_type"] != index_param["table_quantization_type"]: + # Reranking makes sense only if table quantization is not "bit" if index_param["quantization_type"] == "bit" and reranking: # Embeddings needs to be passed to binary_quantize function if quantization_type is bit search_query = sql.Composed( @@ -112,7 +113,7 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: """ SELECT i.id FROM ( - SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance + SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) """ @@ -120,6 +121,8 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: table_name=sql.Identifier(self.table_name), column_name=column_name, reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]), + search_vector=search_vector, + table_quantization_type=sql.SQL(index_param["table_quantization_type"]), quantization_type=sql.SQL(index_param["quantization_type"]), dim=sql.Literal(self.dim), where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), @@ -127,7 +130,7 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL( """ - {search_vector} + {search_vector}::{quantization_type}({dim}) LIMIT {quantized_fetch_limit} ) i ORDER BY i.distance @@ -135,6 +138,8 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: """ ).format( search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), quantized_fetch_limit=sql.Literal( self.case_config.search_param()["quantized_fetch_limit"] ), @@ -154,7 +159,11 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector), + sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format( + search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + ), ] ) else: @@ -167,7 +176,11 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), + sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format( + search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + ), ] ) @@ -334,7 +347,7 @@ def _create_index(self): else: with_clause = sql.Composed(()) - if index_param["quantization_type"] != None: + if index_param["quantization_type"] != index_param["table_quantization_type"]: index_create_sql = sql.SQL( """ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} @@ -377,6 +390,8 @@ def _create_index(self): def _create_table(self, dim: int): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + + index_param = self.case_config.index_param() try: log.info(f"{self.name} client create table : {self.table_name}") @@ -384,8 +399,11 @@ def _create_table(self, dim: int): # create table self.cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" - ).format(table_name=sql.Identifier(self.table_name), dim=dim) + "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));" + ).format( + table_name=sql.Identifier(self.table_name), + table_quantization_type=sql.SQL(index_param["table_quantization_type"]), + dim=dim) ) self.cursor.execute( sql.SQL( @@ -407,19 +425,42 @@ def insert_embeddings( ) -> Tuple[int, Optional[Exception]]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + + index_param = self.case_config.index_param() try: metadata_arr = np.array(metadata) embeddings_arr = np.array(embeddings) - - with self.cursor.copy( - sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( - table_name=sql.Identifier(self.table_name) - ) - ) as copy: - copy.set_types(["bigint", "vector"]) - for i, row in enumerate(metadata_arr): - copy.write_row((row, embeddings_arr[i])) + + if index_param["table_quantization_type"] == "bit": + with self.cursor.copy( + sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format( + table_name=sql.Identifier(self.table_name) + ) + ) as copy: + # Same logic as pgvector binary_quantize + for i, row in enumerate(metadata_arr): + embeddings_bit = '' + for embedding in embeddings_arr[i]: + if embedding > 0: + embeddings_bit += '1' + else: + embeddings_bit += '0' + copy.write_row((str(row), embeddings_bit)) + else: + with self.cursor.copy( + sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( + table_name=sql.Identifier(self.table_name) + ) + ) as copy: + if index_param["table_quantization_type"] == "halfvec": + copy.set_types(["bigint", "halfvec"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, np.float16(embeddings_arr[i]))) + else: + copy.set_types(["bigint", "vector"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, embeddings_arr[i])) self.conn.commit() if kwargs.get("last_batch"): diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 7f076a2d..aa17ab02 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -854,6 +854,19 @@ class CaseConfigInput(BaseModel): ], ) +CaseConfigParamInput_TableQuantizationType_PgVector = CaseConfigInput( + label=CaseConfigParamType.tableQuantizationType, + inputType=InputType.Option, + inputConfig={ + "options": ["none", "bit", "halfvec"], + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + IndexType.HNSW.value, + IndexType.IVFFlat.value, + ], +) + CaseConfigParamInput_max_parallel_workers_PgVectorRS = CaseConfigInput( label=CaseConfigParamType.max_parallel_workers, displayLabel="Max parallel workers", @@ -1149,6 +1162,7 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_m, CaseConfigParamInput_EFConstruction_PgVector, CaseConfigParamInput_QuantizationType_PgVector, + CaseConfigParamInput_TableQuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, ] @@ -1160,6 +1174,7 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_Lists_PgVector, CaseConfigParamInput_Probes_PgVector, CaseConfigParamInput_QuantizationType_PgVector, + CaseConfigParamInput_TableQuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, CaseConfigParamInput_reranking_PgVector, diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 648fb172..ef305b9c 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -47,6 +47,7 @@ class CaseConfigParamType(Enum): probes = "probes" quantizationType = "quantization_type" quantizationRatio = "quantization_ratio" + tableQuantizationType = "table_quantization_type" reranking = "reranking" rerankingMetric = "reranking_metric" quantizedFetchLimit = "quantized_fetch_limit"