Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for HNSW in pgvector #293

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
return WeaviateIndexConfig

if self == DB.PgVector:
from .pgvector.config import PgVectorIndexConfig
return PgVectorIndexConfig
from .pgvector.config import _pgvector_case_config
return _pgvector_case_config.get(index_type)

if self == DB.PgVectoRS:
from .pgvecto_rs.config import _pgvecto_rs_case_config
Expand Down
40 changes: 37 additions & 3 deletions vectordb_bench/backend/clients/pgvector/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, SecretStr
from ..api import DBConfig, DBCaseConfig, MetricType
from ..api import DBConfig, DBCaseConfig, IndexType, MetricType

POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"

Expand All @@ -23,6 +23,7 @@ def to_dict(self) -> dict:

class PgVectorIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
index: IndexType
lists: int | None = 1000
probes: int | None = 10

Expand All @@ -47,15 +48,48 @@ def parse_metric_fun_str(self) -> str:
return "max_inner_product"
return "cosine_distance"



class HNSWConfig(PgVectorIndexConfig):
    """Case configuration for a pgvector HNSW index.

    ``m`` and ``ef_construction`` are reported by :meth:`index_param`;
    ``ef`` is optional and reported by :meth:`search_param`.
    """

    m: int
    ef_construction: int
    ef: int | None = None
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        """Return the build-time options: ``m``, ``ef_construction``, ``metric``."""
        return dict(
            m=self.m,
            ef_construction=self.ef_construction,
            metric=self.parse_metric(),
        )

    def search_param(self) -> dict:
        """Return the query-time options: ``ef`` plus the metric function and operator."""
        return dict(
            ef=self.ef,
            metric_fun=self.parse_metric_fun_str(),
            metric_fun_op=self.parse_metric_fun_op(),
        )


class IVFFlatConfig(PgVectorIndexConfig):
    """Case configuration for a pgvector IVFFlat index.

    ``lists`` is required and reported by :meth:`index_param`; ``probes``
    is optional and reported by :meth:`search_param`.
    """

    lists: int
    probes: int | None = None
    index: IndexType = IndexType.IVFFlat

    def index_param(self) -> dict:
        """Return the build-time options: ``lists`` and ``metric``."""
        return dict(
            lists=self.lists,
            metric=self.parse_metric(),
        )

    def search_param(self) -> dict:
        """Return the query-time options: ``probes`` plus the metric function and operator."""
        return dict(
            probes=self.probes,
            metric_fun=self.parse_metric_fun_str(),
            metric_fun_op=self.parse_metric_fun_op(),
        )
}

# Registry of pgvector case-config classes keyed by index type. Callers look
# it up with ``.get(index_type)``, so an index type that is not listed here
# (or ``None``) yields ``None`` rather than raising.
_pgvector_case_config = {
IndexType.HNSW: HNSWConfig,
IndexType.IVFFlat: IVFFlatConfig,
}
22 changes: 18 additions & 4 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import psycopg2
import psycopg2.extras

from ..api import VectorDB, DBCaseConfig
from ..api import IndexType, VectorDB, DBCaseConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,7 +108,14 @@ def _create_index(self):
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()
self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING ivfflat (embedding {index_param["metric"]}) WITH (lists={index_param["lists"]});')
# Build the index that matches the configured index type. An unsupported
# type must fail loudly: the previous `assert "<non-empty string>"` was
# always true, so no index was created and no error was reported.
index_type = self.case_config.index
if index_type == IndexType.HNSW:
    log.debug(
        "Creating HNSW index. m=%s, ef_construction=%s.",
        index_param["m"],
        index_param["ef_construction"],
    )
    self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING hnsw (embedding {index_param["metric"]}) WITH (m={index_param["m"]}, ef_construction={index_param["ef_construction"]});')
elif index_type == IndexType.IVFFlat:
    log.debug("Creating IVFFLAT index. list=%s.", index_param["lists"])
    self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING ivfflat (embedding {index_param["metric"]}) WITH (lists={index_param["lists"]});')
else:
    raise ValueError(f"Invalid index type {index_type}")
self.conn.commit()

def _create_table(self, dim : int):
Expand Down Expand Up @@ -164,8 +171,15 @@ def search_embedding(
assert self.cursor is not None, "Cursor is not initialized"

search_param =self.case_config.search_param()
self.cursor.execute(f'SET ivfflat.probes = {search_param["probes"]}')
self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};")

# Apply the index-specific search setting, then run the query that both
# index types share. The settings are declared `int | None` in the config,
# so a None value is skipped and the session keeps its current setting
# instead of sending `SET ... = None` to the server.
index_type = self.case_config.index
if index_type == IndexType.HNSW:
    if search_param["ef"] is not None:
        self.cursor.execute(f'SET hnsw.ef_search = {search_param["ef"]}')
elif index_type == IndexType.IVFFlat:
    if search_param["probes"] is not None:
        self.cursor.execute(f'SET ivfflat.probes = {search_param["probes"]}')
else:
    # The previous `assert "<non-empty string>"` was always true, so an
    # unsupported index type fell through without running any query.
    raise ValueError(f"Invalid index type {index_type}")
self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};")
self.conn.commit()
result = self.cursor.fetchall()

Expand Down