From fe6594f31b91b082d7711492e055bc6564978999 Mon Sep 17 00:00:00 2001
From: Wahaj Ali
Date: Tue, 19 Mar 2024 10:41:25 +0500
Subject: [PATCH] Add support for HNSW in pgvector

Split the pgvector case config into one class per index type
(HNSWConfig, IVFFlatConfig), looked up by IndexType through
_pgvector_case_config, and make PgVector._create_index and
PgVector.search_embedding branch on case_config.index so that they
issue the hnsw or the ivfflat statements.

---
Review notes (ignored by git am):
 * HNSWConfig.index_param() now returns the key "ef_construction"; it
   returned "efConstruction" while _create_index reads
   index_param["ef_construction"], which raised KeyError.
 * Dropped the first, shadowed definition of HNSWConfig.index_param().
 * The else branches used `assert "<non-empty string>"`, which never
   fires; they now raise ValueError with an f-string message.
 * NOTE(review): `ef` and `probes` default to None and are interpolated
   into the SET statements as-is -- confirm callers always set them.
 * Line layout was restored from a whitespace-flattened copy using the
   hunk counts; indentation and whitespace on blank lines are a best
   guess. The commit id and the blob ids on the `index` lines are those
   of the original submission.

 vectordb_bench/backend/clients/__init__.py    |  4 +-
 .../backend/clients/pgvector/config.py        | 42 ++++++++++++++++---
 .../backend/clients/pgvector/pgvector.py      | 22 +++++++--
 3 files changed, 59 insertions(+), 9 deletions(-)

diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py
index 3df11610b..bd5211587 100644
--- a/vectordb_bench/backend/clients/__init__.py
+++ b/vectordb_bench/backend/clients/__init__.py
@@ -142,8 +142,8 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
             return WeaviateIndexConfig
 
         if self == DB.PgVector:
-            from .pgvector.config import PgVectorIndexConfig
-            return PgVectorIndexConfig
+            from .pgvector.config import _pgvector_case_config
+            return _pgvector_case_config.get(index_type)
 
         if self == DB.PgVectoRS:
             from .pgvecto_rs.config import _pgvecto_rs_case_config
diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py
index 7d90e86d2..3cf19993c 100644
--- a/vectordb_bench/backend/clients/pgvector/config.py
+++ b/vectordb_bench/backend/clients/pgvector/config.py
@@ -1,5 +1,5 @@
 from pydantic import BaseModel, SecretStr
-from ..api import DBConfig, DBCaseConfig, MetricType
+from ..api import DBConfig, DBCaseConfig, IndexType, MetricType
 
 POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
 
@@ -23,6 +23,7 @@ def to_dict(self) -> dict:
 
 class PgVectorIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType | None = None
+    index: IndexType
     lists: int | None = 1000
     probes: int | None = 10
 
@@ -47,15 +48,50 @@ def parse_metric_fun_str(self) -> str:
             return "max_inner_product"
         return "cosine_distance"
 
+
+
+class HNSWConfig(PgVectorIndexConfig):
+    """Case config for a pgvector HNSW index (M, efConstruction, ef)."""
+    M: int
+    efConstruction: int
+    ef: int | None = None
+    index: IndexType = IndexType.HNSW
+
+    def index_param(self) -> dict:
+        return {
+            "m" : self.M,
+            "ef_construction" : self.efConstruction,
+            "metric" : self.parse_metric()
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "ef" : self.ef,
+            "metric_fun" : self.parse_metric_fun_str(),
+            "metric_fun_op" : self.parse_metric_fun_op(),
+        }
+
+
+class IVFFlatConfig(PgVectorIndexConfig):
+    """Case config for a pgvector IVFFlat index (lists, probes)."""
+    lists: int
+    probes: int | None = None
+    index: IndexType = IndexType.IVFFlat
+
     def index_param(self) -> dict:
         return {
             "lists" : self.lists,
             "metric" : self.parse_metric()
         }
-    
+
     def search_param(self) -> dict:
         return {
             "probes" : self.probes,
             "metric_fun" : self.parse_metric_fun_str(),
             "metric_fun_op" : self.parse_metric_fun_op(),
-        }
\ No newline at end of file
+        }
+
+_pgvector_case_config = {
+    IndexType.HNSW: HNSWConfig,
+    IndexType.IVFFlat: IVFFlatConfig,
+}
\ No newline at end of file
diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py
index e0fc8d3b8..a69b7cbde 100644
--- a/vectordb_bench/backend/clients/pgvector/pgvector.py
+++ b/vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -8,7 +8,7 @@
 import psycopg2
 import psycopg2.extras
 
-from ..api import VectorDB, DBCaseConfig
+from ..api import IndexType, VectorDB, DBCaseConfig
 
 log = logging.getLogger(__name__)
 
@@ -108,7 +108,14 @@ def _create_index(self):
         assert self.cursor is not None, "Cursor is not initialized"
 
         index_param = self.case_config.index_param()
-        self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING ivfflat (embedding {index_param["metric"]}) WITH (lists={index_param["lists"]});')
+        if self.case_config.index == IndexType.HNSW:
+            log.debug(f'Creating HNSW index. m={index_param["m"]}, ef_construction={index_param["ef_construction"]}')
+            self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING hnsw (embedding {index_param["metric"]}) WITH (m={index_param["m"]}, ef_construction={index_param["ef_construction"]});')
+        elif self.case_config.index == IndexType.IVFFlat:
+            log.debug(f'Creating IVFFLAT index. list={index_param["lists"]}')
+            self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING ivfflat (embedding {index_param["metric"]}) WITH (lists={index_param["lists"]});')
+        else:
+            raise ValueError(f"Invalid index type {self.case_config.index}")
         self.conn.commit()
 
     def _create_table(self, dim : int):
@@ -164,8 +171,15 @@ def search_embedding(
         assert self.cursor is not None, "Cursor is not initialized"
 
         search_param =self.case_config.search_param()
-        self.cursor.execute(f'SET ivfflat.probes = {search_param["probes"]}')
-        self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};")
+
+        if self.case_config.index == IndexType.HNSW:
+            self.cursor.execute(f'SET hnsw.ef_search = {search_param["ef"]}')
+            self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};")
+        elif self.case_config.index == IndexType.IVFFlat:
+            self.cursor.execute(f'SET ivfflat.probes = {search_param["probes"]}')
+            self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};")
+        else:
+            raise ValueError(f"Invalid index type {self.case_config.index}")
         self.conn.commit()
 
         result = self.cursor.fetchall()