Skip to content

Commit

Permalink
Add support for HNSW in pgvector
Browse files Browse the repository at this point in the history
  • Loading branch information
wahajali committed Mar 19, 2024
1 parent 43302a9 commit fe6594f
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 9 deletions.
4 changes: 2 additions & 2 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
return WeaviateIndexConfig

if self == DB.PgVector:
from .pgvector.config import PgVectorIndexConfig
return PgVectorIndexConfig
from .pgvector.config import _pgvector_case_config
return _pgvector_case_config.get(index_type)

if self == DB.PgVectoRS:
from .pgvecto_rs.config import _pgvecto_rs_case_config
Expand Down
47 changes: 44 additions & 3 deletions vectordb_bench/backend/clients/pgvector/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, SecretStr
from ..api import DBConfig, DBCaseConfig, MetricType
from ..api import DBConfig, DBCaseConfig, IndexType, MetricType

POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"

Expand All @@ -23,6 +23,7 @@ def to_dict(self) -> dict:

class PgVectorIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
index: IndexType
lists: int | None = 1000
probes: int | None = 10

Expand All @@ -47,15 +48,55 @@ def parse_metric_fun_str(self) -> str:
return "max_inner_product"
return "cosine_distance"



class HNSWConfig(PgVectorIndexConfig):
    """Case configuration for a pgvector HNSW index.

    ``M`` and ``efConstruction`` are the build-time options; they are
    exposed through :meth:`index_param` under the keys ``"m"`` and
    ``"ef_construction"``, which the pgvector client reads when it builds
    the ``CREATE INDEX ... USING hnsw`` statement. ``ef`` is the search-time
    value the client applies with ``SET hnsw.ef_search``.
    """

    M: int
    efConstruction: int
    ef: int | None = None
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        """Return the parameters used to build the HNSW index.

        The keys must stay in sync with the client's index-creation code,
        which looks up ``"m"``, ``"ef_construction"`` and ``"metric"``.
        """
        return {
            "m": self.M,
            # Keyed as "ef_construction" (not "efConstruction"): that is the
            # name the client reads; a mismatch raises KeyError at build time.
            "ef_construction": self.efConstruction,
            "metric": self.parse_metric(),
        }

    def search_param(self) -> dict:
        """Return the parameters used when searching the HNSW index."""
        return {
            "ef": self.ef,
            "metric_fun": self.parse_metric_fun_str(),
            "metric_fun_op": self.parse_metric_fun_op(),
        }


class IVFFlatConfig(PgVectorIndexConfig):
    """Case configuration for a pgvector IVFFlat index.

    ``lists`` is the build-time option the client places in the
    ``CREATE INDEX ... USING ivfflat ... WITH (lists=...)`` statement.
    ``probes`` is the search-time value the client applies with
    ``SET ivfflat.probes``.
    """

    lists: int
    probes: int | None = None
    index: IndexType = IndexType.IVFFlat

    def index_param(self) -> dict:
        """Return the parameters used to build the IVFFlat index."""
        return {
            "lists": self.lists,
            "metric": self.parse_metric(),
        }

    def search_param(self) -> dict:
        """Return the parameters used when searching the IVFFlat index."""
        return {
            "probes": self.probes,
            "metric_fun": self.parse_metric_fun_str(),
            "metric_fun_op": self.parse_metric_fun_op(),
        }

# Registry mapping each supported pgvector index type to its case-config
# class. Callers resolve it with `.get(index_type)`, so an index type that
# is not listed here yields None rather than raising.
_pgvector_case_config = {
IndexType.HNSW: HNSWConfig,
IndexType.IVFFlat: IVFFlatConfig,
}
22 changes: 18 additions & 4 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import psycopg2
import psycopg2.extras

from ..api import VectorDB, DBCaseConfig
from ..api import IndexType, VectorDB, DBCaseConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,7 +108,14 @@ def _create_index(self):
# Build the vector index on the embedding column, choosing the access
# method (hnsw / ivfflat) from the index type declared on the case config.
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()
if self.case_config.index == IndexType.HNSW:
    log.debug(f'Creating HNSW index. m={index_param["m"]}, ef_construction={index_param["ef_construction"]}')
    self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING hnsw (embedding {index_param["metric"]}) WITH (m={index_param["m"]}, ef_construction={index_param["ef_construction"]});')
elif self.case_config.index == IndexType.IVFFlat:
    log.debug(f'Creating IVFFLAT index. list={index_param["lists"]}')
    self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING ivfflat (embedding {index_param["metric"]}) WITH (lists={index_param["lists"]});')
else:
    # Raise explicitly: asserting a non-empty string literal always passes,
    # so an unsupported index type would otherwise slip through silently
    # and no index would be created.
    raise ValueError(f"Invalid index type {self.case_config.index}")
self.conn.commit()

def _create_table(self, dim : int):
Expand Down Expand Up @@ -164,8 +171,15 @@ def search_embedding(
assert self.cursor is not None, "Cursor is not initialized"

search_param = self.case_config.search_param()

# Apply the session-level search setting that matches the index type
# before running the query; each index type exposes a different knob.
if self.case_config.index == IndexType.HNSW:
    self.cursor.execute(f'SET hnsw.ef_search = {search_param["ef"]}')
elif self.case_config.index == IndexType.IVFFlat:
    self.cursor.execute(f'SET ivfflat.probes = {search_param["probes"]}')
else:
    # Raise explicitly: asserting a non-empty string literal always passes,
    # so an unsupported index type would otherwise go undetected here.
    raise ValueError(f"Invalid index type {self.case_config.index}")

# The nearest-neighbour query is the same for every index type, so it is
# issued once after the per-index setting has been applied.
self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};")
self.conn.commit()
result = self.cursor.fetchall()

Expand Down

0 comments on commit fe6594f

Please sign in to comment.