diff --git a/pyproject.toml b/pyproject.toml index a7c5c892b..3a72d15c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ qdrant = [ "qdrant-client" ] pinecone = [ "pinecone-client" ] weaviate = [ "weaviate-client" ] elastic = [ "elasticsearch" ] -pgvector = [ "pgvector", "sqlalchemy" ] +pgvector = [ "pgvector", "psycopg2" ] pgvecto_rs = [ "psycopg2" ] redis = [ "redis" ] chromadb = [ "chromadb" ] diff --git a/vectordb_bench/backend/clients/pgvecto_rs/config.py b/vectordb_bench/backend/clients/pgvecto_rs/config.py index 7ad5983f1..73c41c239 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/config.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/config.py @@ -8,42 +8,30 @@ class PgVectoRSConfig(DBConfig): user_name: SecretStr = "postgres" password: SecretStr - url: SecretStr + host: str = "localhost" + port: int = 5432 db_name: str def to_dict(self) -> dict: user_str = self.user_name.get_secret_value() pwd_str = self.password.get_secret_value() - url_str = self.url.get_secret_value() - host, port = url_str.split(":") return { - "host": host, - "port": port, + "host": self.host, + "port": self.port, "dbname": self.db_name, "user": user_str, - "password": pwd_str, + "password": pwd_str } - class PgVectoRSIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None - quantizationType: Literal["trivial", "scalar", "product"] - quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"] - - def parse_quantization(self) -> str: - if self.quantizationType == "trivial": - return "quantization = { trivial = { } }" - elif self.quantizationType == "scalar": - return "quantization = { scalar = { } }" - else: - return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}' def parse_metric(self) -> str: if self.metric_type == MetricType.L2: - return "l2_ops" + return "vector_l2_ops" elif self.metric_type == MetricType.IP: - return "dot_ops" - return "cosine_ops" + return "vector_dot_ops" + return "vector_cos_ops" def parse_metric_fun_op(self) -> str: if self.metric_type == MetricType.L2: @@ -52,16 +40,27 @@ def parse_metric_fun_op(self) -> str: return "<#>" return "<=>" +class PgVectoRSQuantConfig(PgVectoRSIndexConfig): + quantizationType: Literal["trivial", "scalar", "product"] + quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"] -class HNSWConfig(PgVectoRSIndexConfig): + def parse_quantization(self) -> str: + if self.quantizationType == "trivial": + return "quantization = { trivial = { } }" + elif self.quantizationType == "scalar": + return "quantization = { scalar = { } }" + else: + return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}' + + +class HNSWConfig(PgVectoRSQuantConfig): M: int efConstruction: int index: IndexType = IndexType.HNSW def index_param(self) -> dict: options = f""" -capacity = 1048576 -[algorithm.hnsw] +[indexing.hnsw] m = {self.M} ef_construction = {self.efConstruction} {self.parse_quantization()} @@ -72,17 +71,16 @@ def search_param(self) -> dict: return {"metrics_op": self.parse_metric_fun_op()} -class IVFFlatConfig(PgVectoRSIndexConfig): +class IVFFlatConfig(PgVectoRSQuantConfig): nlist: int nprobe: int | None = None index: IndexType = IndexType.IVFFlat def index_param(self) -> dict: options = f""" -capacity = 1048576 -[algorithm.ivf] +[indexing.ivf] nlist = {self.nlist} -nprob = {self.nprobe if self.nprobe else 10} +nsample = {self.nprobe if self.nprobe else 10} {self.parse_quantization()} """ return {"options": options, "metric": self.parse_metric()} @@ -90,14 +88,29 @@ def index_param(self) -> dict: def search_param(self) -> dict: return {"metrics_op": self.parse_metric_fun_op()} +class IVFFlatSQ8Config(PgVectoRSIndexConfig): + nlist: int + nprobe: int | None = None + index: IndexType = IndexType.IVFSQ8 + + def index_param(self) -> dict: + options = f""" +[indexing.ivf] +nlist = {self.nlist} +nsample = {self.nprobe if self.nprobe else 10} +quantization = {{ scalar = {{ }} }} +""" + return {"options": options, "metric": self.parse_metric()} + + def search_param(self) -> dict: + return {"metrics_op": self.parse_metric_fun_op()} -class FLATConfig(PgVectoRSIndexConfig): +class FLATConfig(PgVectoRSQuantConfig): index: IndexType = IndexType.Flat def index_param(self) -> dict: options = f""" -capacity = 1048576 -[algorithm.flat] +[indexing.flat] {self.parse_quantization()} """ return {"options": options, "metric": self.parse_metric()} @@ -107,9 +120,8 @@ def search_param(self) -> dict: _pgvecto_rs_case_config = { - IndexType.AUTOINDEX: HNSWConfig, IndexType.HNSW: HNSWConfig, - IndexType.DISKANN: HNSWConfig, IndexType.IVFFlat: IVFFlatConfig, + IndexType.IVFSQ8: IVFFlatSQ8Config, IndexType.Flat: FLATConfig, } diff --git a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py index 22476522c..22caa43e6 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py @@ -1,18 +1,17 @@ -"""Wrapper around the Pgvector vector database over VectorDB""" +"""Wrapper around the Pgvecto.rs vector database over VectorDB""" import io import logging from contextlib import contextmanager from typing import Any import pandas as pd - import psycopg2 +import psycopg2.extras from ..api import VectorDB, DBCaseConfig log = logging.getLogger(__name__) - class PgVectoRS(VectorDB): """Use SQLAlchemy instructions""" @@ -66,6 +65,8 @@ def init(self) -> None: self.conn = psycopg2.connect(**self.db_config) self.conn.autocommit = False self.cursor = self.conn.cursor() + self.cursor.execute('SET search_path = "$user", public, vectors') + self.conn.commit() try: yield @@ -113,7 +114,7 @@ def _create_index(self): self.conn.commit() except Exception as e: log.warning( - f"Failed to create pgvector table: {self.table_name} error: {e}" + f"Failed to create pgvecto.rs table: {self.table_name} error: {e}" ) raise e from None @@ -127,13 +128,10 @@ def _create_table(self, dim: int): f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" \ (id Integer PRIMARY KEY, embedding vector({dim}));' ) - self.cursor.execute( - f'ALTER TABLE public."{self.table_name}" ALTER COLUMN embedding SET STORAGE PLAIN;' - ) self.conn.commit() except Exception as e: log.warning( - f"Failed to create pgvector table: {self.table_name} error: {e}" + f"Failed to create pgvecto.rs table: {self.table_name} error: {e}" ) raise e from None @@ -146,22 +144,24 @@ def insert_embeddings( assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + try: - items = {"id": metadata, "embedding": embeddings} + items = { + "id": metadata, + "embedding": embeddings + } df = pd.DataFrame(items) csv_buffer = io.StringIO() df.to_csv(csv_buffer, index=False, header=False) csv_buffer.seek(0) - self.cursor.copy_expert( - f'COPY public."{self.table_name}" FROM STDIN WITH (FORMAT CSV)', - csv_buffer, - ) + self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer) self.conn.commit() return len(metadata), None except Exception as e: - log.warning( - f"Failed to insert data into pgvector table ({self.table_name}), error: {e}" - ) + log.warning(f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}") + return 0, e def search_embedding( self, diff --git a/vectordb_bench/frontend/const/dbCaseConfigs.py b/vectordb_bench/frontend/const/dbCaseConfigs.py index fad5f362d..24de64725 100644 --- a/vectordb_bench/frontend/const/dbCaseConfigs.py +++ b/vectordb_bench/frontend/const/dbCaseConfigs.py @@ -397,6 +397,11 @@ class CaseConfigInput(BaseModel): inputConfig={ "options": ["trivial", "scalar", "product"], }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + IndexType.HNSW.value, + IndexType.IVFFlat.value, + ], ) CaseConfigParamInput_QuantizationRatio_PgVectoRS = CaseConfigInput( @@ -406,7 +411,11 @@ class CaseConfigInput(BaseModel): "options": ["x4", "x8", "x16", "x32", "x64"], }, isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) - == "product", + == "product" and config.get(CaseConfigParamType.IndexType, None) + in [ + IndexType.HNSW.value, + IndexType.IVFFlat.value, + ], ) CaseConfigParamInput_ZillizLevel = CaseConfigInput(