Skip to content

Commit

Permalink
fix: new interface of pgvecto.rs after v0.2
Browse files Browse the repository at this point in the history
Signed-off-by: cutecutecat <[email protected]>
  • Loading branch information
cutecutecat committed Apr 16, 2024
1 parent c50123e commit 6f1cf45
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ qdrant = [ "qdrant-client" ]
pinecone = [ "pinecone-client" ]
weaviate = [ "weaviate-client" ]
elastic = [ "elasticsearch" ]
pgvector = [ "pgvector", "sqlalchemy" ]
pgvector = [ "pgvector", "psycopg2" ]
pgvecto_rs = [ "psycopg2" ]
redis = [ "redis" ]
chromadb = [ "chromadb" ]
Expand Down
76 changes: 44 additions & 32 deletions vectordb_bench/backend/clients/pgvecto_rs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,30 @@
class PgVectoRSConfig(DBConfig):
user_name: SecretStr = "postgres"
password: SecretStr
url: SecretStr
host: str = "localhost"
port: int = 5432
db_name: str

def to_dict(self) -> dict:
user_str = self.user_name.get_secret_value()
pwd_str = self.password.get_secret_value()
url_str = self.url.get_secret_value()
host, port = url_str.split(":")
return {
"host": host,
"port": port,
"host": self.host,
"port": self.port,
"dbname": self.db_name,
"user": user_str,
"password": pwd_str,
"password": pwd_str
}


class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
quantizationType: Literal["trivial", "scalar", "product"]
quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"]

def parse_quantization(self) -> str:
if self.quantizationType == "trivial":
return "quantization = { trivial = { } }"
elif self.quantizationType == "scalar":
return "quantization = { scalar = { } }"
else:
return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}'

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_ops"
return "vector_l2_ops"
elif self.metric_type == MetricType.IP:
return "dot_ops"
return "cosine_ops"
return "vector_dot_ops"
return "vector_cos_ops"

def parse_metric_fun_op(self) -> str:
if self.metric_type == MetricType.L2:
Expand All @@ -52,16 +40,27 @@ def parse_metric_fun_op(self) -> str:
return "<#>"
return "<=>"

class PgVectoRSQuantConfig(PgVectoRSIndexConfig):
quantizationType: Literal["trivial", "scalar", "product"]
quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"]

class HNSWConfig(PgVectoRSIndexConfig):
def parse_quantization(self) -> str:
if self.quantizationType == "trivial":
return "quantization = { trivial = { } }"
elif self.quantizationType == "scalar":
return "quantization = { scalar = { } }"
else:
return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}'


class HNSWConfig(PgVectoRSQuantConfig):
M: int
efConstruction: int
index: IndexType = IndexType.HNSW

def index_param(self) -> dict:
options = f"""
capacity = 1048576
[algorithm.hnsw]
[indexing.hnsw]
m = {self.M}
ef_construction = {self.efConstruction}
{self.parse_quantization()}
Expand All @@ -72,32 +71,46 @@ def search_param(self) -> dict:
return {"metrics_op": self.parse_metric_fun_op()}


class IVFFlatConfig(PgVectoRSIndexConfig):
class IVFFlatConfig(PgVectoRSQuantConfig):
nlist: int
nprobe: int | None = None
index: IndexType = IndexType.IVFFlat

def index_param(self) -> dict:
options = f"""
capacity = 1048576
[algorithm.ivf]
[indexing.ivf]
nlist = {self.nlist}
nprob = {self.nprobe if self.nprobe else 10}
nsample = {self.nprobe if self.nprobe else 10}
{self.parse_quantization()}
"""
return {"options": options, "metric": self.parse_metric()}

def search_param(self) -> dict:
return {"metrics_op": self.parse_metric_fun_op()}

class IVFFlatSQ8Config(PgVectoRSIndexConfig):
    """Case config for a pgvecto.rs IVF index with scalar quantization.

    The quantization setting is fixed to ``scalar`` in the generated index
    options; only ``nlist`` and ``nprobe`` are configurable.
    """

    nlist: int
    nprobe: int | None = None
    index: IndexType = IndexType.IVFSQ8

    def index_param(self) -> dict:
        """Build the index-creation parameters.

        Returns:
            A dict with ``options`` (the option text for the ``[indexing.ivf]``
            section) and ``metric`` (the value of ``parse_metric()``).
        """
        # A missing (or otherwise falsy) nprobe falls back to 10.
        # NOTE(review): nprobe is what feeds the `nsample` option here --
        # confirm against the pgvecto.rs indexing docs that this is intended.
        nsample = self.nprobe or 10
        options = f"""
[indexing.ivf]
nlist = {self.nlist}
nsample = {nsample}
quantization = {{ scalar = {{ }} }}
"""
        return {"options": options, "metric": self.parse_metric()}

    def search_param(self) -> dict:
        """Return the operator string from ``parse_metric_fun_op()`` as ``metrics_op``."""
        return {"metrics_op": self.parse_metric_fun_op()}

class FLATConfig(PgVectoRSIndexConfig):
class FLATConfig(PgVectoRSQuantConfig):
index: IndexType = IndexType.Flat

def index_param(self) -> dict:
options = f"""
capacity = 1048576
[algorithm.flat]
[indexing.flat]
{self.parse_quantization()}
"""
return {"options": options, "metric": self.parse_metric()}
Expand All @@ -107,9 +120,8 @@ def search_param(self) -> dict:


_pgvecto_rs_case_config = {
IndexType.AUTOINDEX: HNSWConfig,
IndexType.HNSW: HNSWConfig,
IndexType.DISKANN: HNSWConfig,
IndexType.IVFFlat: IVFFlatConfig,
IndexType.IVFSQ8: IVFFlatSQ8Config,
IndexType.Flat: FLATConfig,
}
32 changes: 16 additions & 16 deletions vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
"""Wrapper around the Pgvector vector database over VectorDB"""
"""Wrapper around the Pgvecto.rs vector database over VectorDB"""

import io
import logging
from contextlib import contextmanager
from typing import Any
import pandas as pd

import psycopg2
import psycopg2.extras

from ..api import VectorDB, DBCaseConfig

log = logging.getLogger(__name__)


class PgVectoRS(VectorDB):
"""Use SQLAlchemy instructions"""

Expand Down Expand Up @@ -66,6 +65,8 @@ def init(self) -> None:
self.conn = psycopg2.connect(**self.db_config)
self.conn.autocommit = False
self.cursor = self.conn.cursor()
self.cursor.execute('SET search_path = "$user", public, vectors')
self.conn.commit()

try:
yield
Expand Down Expand Up @@ -113,7 +114,7 @@ def _create_index(self):
self.conn.commit()
except Exception as e:
log.warning(
f"Failed to create pgvector table: {self.table_name} error: {e}"
f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
)
raise e from None

Expand All @@ -127,13 +128,10 @@ def _create_table(self, dim: int):
f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" \
(id Integer PRIMARY KEY, embedding vector({dim}));'
)
self.cursor.execute(
f'ALTER TABLE public."{self.table_name}" ALTER COLUMN embedding SET STORAGE PLAIN;'
)
self.conn.commit()
except Exception as e:
log.warning(
f"Failed to create pgvector table: {self.table_name} error: {e}"
f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
)
raise e from None

Expand All @@ -146,22 +144,24 @@ def insert_embeddings(
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

try:
items = {"id": metadata, "embedding": embeddings}
items = {
"id": metadata,
"embedding": embeddings
}
df = pd.DataFrame(items)
csv_buffer = io.StringIO()
df.to_csv(csv_buffer, index=False, header=False)
csv_buffer.seek(0)
self.cursor.copy_expert(
f'COPY public."{self.table_name}" FROM STDIN WITH (FORMAT CSV)',
csv_buffer,
)
self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer)
self.conn.commit()
return len(metadata), None
except Exception as e:
log.warning(
f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
)
log.warning(f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}")
return 0, e

def search_embedding(
self,
Expand Down
11 changes: 10 additions & 1 deletion vectordb_bench/frontend/const/dbCaseConfigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,11 @@ class CaseConfigInput(BaseModel):
inputConfig={
"options": ["trivial", "scalar", "product"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
],
)

CaseConfigParamInput_QuantizationRatio_PgVectoRS = CaseConfigInput(
Expand All @@ -406,7 +411,11 @@ class CaseConfigInput(BaseModel):
"options": ["x4", "x8", "x16", "x32", "x64"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None)
== "product",
== "product" and config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
],
)

CaseConfigParamInput_ZillizLevel = CaseConfigInput(
Expand Down

0 comments on commit 6f1cf45

Please sign in to comment.