Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: new interface of pgvecto.rs after v0.2 #301

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ qdrant = [ "qdrant-client" ]
pinecone = [ "pinecone-client" ]
weaviate = [ "weaviate-client" ]
elastic = [ "elasticsearch" ]
pgvector = [ "pgvector", "sqlalchemy" ]
pgvector = [ "pgvector", "psycopg2" ]
pgvecto_rs = [ "psycopg2" ]
redis = [ "redis" ]
chromadb = [ "chromadb" ]
Expand Down
76 changes: 44 additions & 32 deletions vectordb_bench/backend/clients/pgvecto_rs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,30 @@
class PgVectoRSConfig(DBConfig):
user_name: SecretStr = "postgres"
password: SecretStr
url: SecretStr
host: str = "localhost"
port: int = 5432
db_name: str

def to_dict(self) -> dict:
user_str = self.user_name.get_secret_value()
pwd_str = self.password.get_secret_value()
url_str = self.url.get_secret_value()
host, port = url_str.split(":")
return {
"host": host,
"port": port,
"host": self.host,
"port": self.port,
"dbname": self.db_name,
"user": user_str,
"password": pwd_str,
"password": pwd_str
}


class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
quantizationType: Literal["trivial", "scalar", "product"]
quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"]

def parse_quantization(self) -> str:
if self.quantizationType == "trivial":
return "quantization = { trivial = { } }"
elif self.quantizationType == "scalar":
return "quantization = { scalar = { } }"
else:
return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}'

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_ops"
return "vector_l2_ops"
elif self.metric_type == MetricType.IP:
return "dot_ops"
return "cosine_ops"
return "vector_dot_ops"
return "vector_cos_ops"

def parse_metric_fun_op(self) -> str:
if self.metric_type == MetricType.L2:
Expand All @@ -52,16 +40,27 @@ def parse_metric_fun_op(self) -> str:
return "<#>"
return "<=>"

class PgVectoRSQuantConfig(PgVectoRSIndexConfig):
quantizationType: Literal["trivial", "scalar", "product"]
quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"]

class HNSWConfig(PgVectoRSIndexConfig):
def parse_quantization(self) -> str:
if self.quantizationType == "trivial":
return "quantization = { trivial = { } }"
elif self.quantizationType == "scalar":
return "quantization = { scalar = { } }"
else:
return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}'


class HNSWConfig(PgVectoRSQuantConfig):
M: int
efConstruction: int
index: IndexType = IndexType.HNSW

def index_param(self) -> dict:
options = f"""
capacity = 1048576
[algorithm.hnsw]
[indexing.hnsw]
m = {self.M}
ef_construction = {self.efConstruction}
{self.parse_quantization()}
Expand All @@ -72,32 +71,46 @@ def search_param(self) -> dict:
return {"metrics_op": self.parse_metric_fun_op()}


class IVFFlatConfig(PgVectoRSIndexConfig):
class IVFFlatConfig(PgVectoRSQuantConfig):
nlist: int
nprobe: int | None = None
index: IndexType = IndexType.IVFFlat

def index_param(self) -> dict:
options = f"""
capacity = 1048576
[algorithm.ivf]
[indexing.ivf]
nlist = {self.nlist}
nprob = {self.nprobe if self.nprobe else 10}
nsample = {self.nprobe if self.nprobe else 10}
{self.parse_quantization()}
"""
return {"options": options, "metric": self.parse_metric()}

def search_param(self) -> dict:
return {"metrics_op": self.parse_metric_fun_op()}

class IVFFlatSQ8Config(PgVectoRSIndexConfig):
    """Case config for a pgvecto.rs IVF index whose quantization is fixed to scalar.

    The quantization clause is written literally into the index options
    instead of being derived from a configurable quantization type.
    """

    nlist: int
    nprobe: int | None = None
    index: IndexType = IndexType.IVFSQ8

    def index_param(self) -> dict:
        """Build the parameters used when creating the index.

        Returns:
            A dict with ``"options"`` (the index options text) and
            ``"metric"`` (the value produced by ``parse_metric()``).
        """
        # NOTE(review): ``nprobe`` is emitted as ``nsample`` here, falling back
        # to 10 when it is unset or 0 — confirm this mapping is intended.
        sample_count = self.nprobe or 10
        options = f"""
[indexing.ivf]
nlist = {self.nlist}
nsample = {sample_count}
quantization = {{ scalar = {{ }} }}
"""
        metric = self.parse_metric()
        return {"options": options, "metric": metric}

    def search_param(self) -> dict:
        """Return the distance operator to use in search queries."""
        operator = self.parse_metric_fun_op()
        return {"metrics_op": operator}

class FLATConfig(PgVectoRSIndexConfig):
class FLATConfig(PgVectoRSQuantConfig):
index: IndexType = IndexType.Flat

def index_param(self) -> dict:
options = f"""
capacity = 1048576
[algorithm.flat]
[indexing.flat]
{self.parse_quantization()}
"""
return {"options": options, "metric": self.parse_metric()}
Expand All @@ -107,9 +120,8 @@ def search_param(self) -> dict:


_pgvecto_rs_case_config = {
IndexType.AUTOINDEX: HNSWConfig,
IndexType.HNSW: HNSWConfig,
IndexType.DISKANN: HNSWConfig,
IndexType.IVFFlat: IVFFlatConfig,
IndexType.IVFSQ8: IVFFlatSQ8Config,
IndexType.Flat: FLATConfig,
}
32 changes: 16 additions & 16 deletions vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
"""Wrapper around the Pgvector vector database over VectorDB"""
"""Wrapper around the Pgvecto.rs vector database over VectorDB"""

import io
import logging
from contextlib import contextmanager
from typing import Any
import pandas as pd

import psycopg2
import psycopg2.extras

from ..api import VectorDB, DBCaseConfig

log = logging.getLogger(__name__)


class PgVectoRS(VectorDB):
"""Use SQLAlchemy instructions"""

Expand Down Expand Up @@ -66,6 +65,8 @@ def init(self) -> None:
self.conn = psycopg2.connect(**self.db_config)
self.conn.autocommit = False
self.cursor = self.conn.cursor()
self.cursor.execute('SET search_path = "$user", public, vectors')
self.conn.commit()

try:
yield
Expand Down Expand Up @@ -113,7 +114,7 @@ def _create_index(self):
self.conn.commit()
except Exception as e:
log.warning(
f"Failed to create pgvector table: {self.table_name} error: {e}"
f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
)
raise e from None

Expand All @@ -127,13 +128,10 @@ def _create_table(self, dim: int):
f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" \
(id Integer PRIMARY KEY, embedding vector({dim}));'
)
self.cursor.execute(
f'ALTER TABLE public."{self.table_name}" ALTER COLUMN embedding SET STORAGE PLAIN;'
)
self.conn.commit()
except Exception as e:
log.warning(
f"Failed to create pgvector table: {self.table_name} error: {e}"
f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
)
raise e from None

Expand All @@ -146,22 +144,24 @@ def insert_embeddings(
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

try:
items = {"id": metadata, "embedding": embeddings}
items = {
"id": metadata,
"embedding": embeddings
}
df = pd.DataFrame(items)
csv_buffer = io.StringIO()
df.to_csv(csv_buffer, index=False, header=False)
csv_buffer.seek(0)
self.cursor.copy_expert(
f'COPY public."{self.table_name}" FROM STDIN WITH (FORMAT CSV)',
csv_buffer,
)
self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer)
self.conn.commit()
return len(metadata), None
except Exception as e:
log.warning(
f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
)
log.warning(f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}")
return 0, e

def search_embedding(
self,
Expand Down
11 changes: 10 additions & 1 deletion vectordb_bench/frontend/const/dbCaseConfigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,11 @@ class CaseConfigInput(BaseModel):
inputConfig={
"options": ["trivial", "scalar", "product"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
],
)

CaseConfigParamInput_QuantizationRatio_PgVectoRS = CaseConfigInput(
Expand All @@ -406,7 +411,11 @@ class CaseConfigInput(BaseModel):
"options": ["x4", "x8", "x16", "x32", "x64"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None)
== "product",
== "product" and config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
],
)

CaseConfigParamInput_ZillizLevel = CaseConfigInput(
Expand Down