From f4c61938d02e61e64a929ffebf5a9cb31a29f467 Mon Sep 17 00:00:00 2001 From: Luca Giacchino Date: Fri, 9 Aug 2024 13:29:03 -0700 Subject: [PATCH] Add quantization option for pgvector with support for halfvec --- README.md | 2 + .../backend/clients/pgvector/cli.py | 16 ++- .../backend/clients/pgvector/config.py | 25 +++- .../backend/clients/pgvector/pgvector.py | 113 +++++++++++++----- .../frontend/config/dbCaseConfigs.py | 15 +++ 5 files changed, 135 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 1ce4564ef..f76ac5c10 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,8 @@ Options: --m INTEGER hnsw m --ef-construction INTEGER hnsw ef-construction --ef-search INTEGER hnsw ef-search + --quantization-type [none|halfvec] + quantization type for vectors --help Show this message and exit. ``` #### Using a configuration file. diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index 31b268231..4e0922694 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -56,7 +56,15 @@ class PgVectorTypedDict(CommonTypedDict): required=False, ), ] - + quantization_type: Annotated[ + Optional[str], + click.option( + "--quantization-type", + type=click.Choice(["none", "halfvec"]), + help="quantization type for vectors", + required=False, + ), + ] class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict): ... 
@@ -79,7 +87,10 @@ def PgVectorIVFFlat( db_name=parameters["db_name"], ), db_case_config=PgVectorIVFFlatConfig( - metric_type=None, lists=parameters["lists"], probes=parameters["probes"] + metric_type=None, + lists=parameters["lists"], + probes=parameters["probes"], + quantization_type=parameters["quantization_type"], ), **parameters, ) @@ -111,6 +122,7 @@ def PgVectorHNSW( ef_search=parameters["ef_search"], maintenance_work_mem=parameters["maintenance_work_mem"], max_parallel_workers=parameters["max_parallel_workers"], + quantization_type=parameters["quantization_type"], ), **parameters, ) diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index 496a3b440..857211234 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -59,11 +59,18 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig): create_index_after_load: bool = True def parse_metric(self) -> str: - if self.metric_type == MetricType.L2: - return "vector_l2_ops" - elif self.metric_type == MetricType.IP: - return "vector_ip_ops" - return "vector_cosine_ops" + if self.quantization_type == "halfvec": + if self.metric_type == MetricType.L2: + return "halfvec_l2_ops" + elif self.metric_type == MetricType.IP: + return "halfvec_ip_ops" + return "halfvec_cosine_ops" + else: + if self.metric_type == MetricType.L2: + return "vector_l2_ops" + elif self.metric_type == MetricType.IP: + return "vector_ip_ops" + return "vector_cosine_ops" def parse_metric_fun_op(self) -> LiteralString: if self.metric_type == MetricType.L2: @@ -143,9 +150,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig): index: IndexType = IndexType.ES_IVFFlat maintenance_work_mem: Optional[str] = None max_parallel_workers: Optional[int] = None + quantization_type: Optional[str] = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"lists": self.lists} + if self.quantization_type == "none": + 
self.quantization_type = None return { "metric": self.parse_metric(), "index_type": self.index.value, @@ -154,6 +164,7 @@ def index_param(self) -> PgVectorIndexParam: ), "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, + "quantization_type": self.quantization_type, } def search_param(self) -> PgVectorSearchParam: @@ -183,9 +194,12 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): index: IndexType = IndexType.ES_HNSW maintenance_work_mem: Optional[str] = None max_parallel_workers: Optional[int] = None + quantization_type: Optional[str] = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"m": self.m, "ef_construction": self.ef_construction} + if self.quantization_type == "none": + self.quantization_type = None return { "metric": self.parse_metric(), "index_type": self.index.value, @@ -194,6 +208,7 @@ def index_param(self) -> PgVectorIndexParam: ), "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, + "quantization_type": self.quantization_type, } def search_param(self) -> PgVectorSearchParam: diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 102481d8d..8123acf18 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -112,25 +112,63 @@ def init(self) -> Generator[None, None, None]: self.cursor.execute(command) self.conn.commit() - self._filtered_search = sql.Composed( - [ - sql.SQL( - "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " - ).format(table_name=sql.Identifier(self.table_name)), - sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), - ] - ) + index_param = self.case_config.index_param() + # The following sections assume that the quantization_type value matches the quantization function name + if index_param["quantization_type"] 
is not None: + self._filtered_search = sql.Composed( + [ + sql.SQL( + "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) " + ).format( + table_name=sql.Identifier(self.table_name), + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + ), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format( + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + ), + ] + ) + else: + self._filtered_search = sql.Composed( + [ + sql.SQL( + "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " + ).format(table_name=sql.Identifier(self.table_name)), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::vector LIMIT %s::int"), + ] + ) - self._unfiltered_search = sql.Composed( - [ - sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format( - sql.Identifier(self.table_name) - ), - sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), - ] - ) + if index_param["quantization_type"] is not None: + self._unfiltered_search = sql.Composed( + [ + sql.SQL( + "SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) " + ).format( + table_name=sql.Identifier(self.table_name), + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + ), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format( + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + ), + ] + ) + else: + self._unfiltered_search = sql.Composed( + [ + sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format( + sql.Identifier(self.table_name) + ), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::vector LIMIT %s::int"), + ] + ) try: yield @@ -265,17 +303,34 @@ def 
_create_index(self): else: with_clause = sql.Composed(()) - index_create_sql = sql.SQL( - """ - CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} - USING {index_type} (embedding {embedding_metric}) - """ - ).format( - index_name=sql.Identifier(self._index_name), - table_name=sql.Identifier(self.table_name), - index_type=sql.Identifier(index_param["index_type"]), - embedding_metric=sql.Identifier(index_param["metric"]), - ) + if index_param["quantization_type"] is not None: + index_create_sql = sql.SQL( + """ + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric}) + """ + ).format( + index_name=sql.Identifier(self._index_name), + table_name=sql.Identifier(self.table_name), + index_type=sql.Identifier(index_param["index_type"]), + # This assumes that the quantization_type value matches the SQL type name used in the embedding cast + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + embedding_metric=sql.Identifier(index_param["metric"]), + ) + else: + index_create_sql = sql.SQL( + """ + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + USING {index_type} (embedding {embedding_metric}) + """ + ).format( + index_name=sql.Identifier(self._index_name), + table_name=sql.Identifier(self.table_name), + index_type=sql.Identifier(index_param["index_type"]), + embedding_metric=sql.Identifier(index_param["metric"]), + ) + index_create_sql_with_with_clause = ( index_create_sql + with_clause ).join(" ") diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 13634eca5..78d1936d7 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -738,6 +738,19 @@ class CaseConfigInput(BaseModel): ], ) +CaseConfigParamInput_QuantizationType_PgVector = CaseConfigInput( + label=CaseConfigParamType.quantizationType, + inputType=InputType.Option, + 
inputConfig={ + "options": ["none", "halfvec"], + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + IndexType.HNSW.value, + IndexType.IVFFlat.value, + ], +) + CaseConfigParamInput_QuantizationRatio_PgVectoRS = CaseConfigInput( label=CaseConfigParamType.quantizationRatio, inputType=InputType.Option, @@ -831,6 +844,7 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_Lists_PgVector, CaseConfigParamInput_m, CaseConfigParamInput_EFConstruction_PgVector, + CaseConfigParamInput_QuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, ] @@ -841,6 +855,7 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_EFSearch_PgVector, CaseConfigParamInput_Lists_PgVector, CaseConfigParamInput_Probes_PgVector, + CaseConfigParamInput_QuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, ]