Skip to content

Commit

Permalink
Add pgdiskann client
Browse files Browse the repository at this point in the history
  • Loading branch information
wahajali committed Aug 28, 2024
1 parent bea6875 commit 3da8a9d
Show file tree
Hide file tree
Showing 7 changed files with 569 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ All the database client supported
| pgvector | `pip install vectordb-bench[pgvector]` |
| pgvecto.rs | `pip install vectordb-bench[pgvecto_rs]` |
| pgvectorscale | `pip install vectordb-bench[pgvectorscale]` |
| pgdiskann | `pip install vectordb-bench[pgdiskann]` |
| redis | `pip install vectordb-bench[redis]` |
| memorydb | `pip install vectordb-bench[memorydb]` |
| chromadb | `pip install vectordb-bench[chromadb]` |
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ weaviate = [ "weaviate-client" ]
elastic = [ "elasticsearch" ]
pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvectorscale = [ "psycopg", "psycopg-binary", "pgvector" ]
pgdiskann = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "pgvecto_rs[psycopg3]>=0.2.1" ]
redis = [ "redis" ]
memorydb = [ "memorydb" ]
Expand Down
13 changes: 13 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class DB(Enum):
PgVector = "PgVector"
PgVectoRS = "PgVectoRS"
PgVectorScale = "PgVectorScale"
PgDiskANN = "PgDiskANN"
Redis = "Redis"
MemoryDB = "MemoryDB"
Chroma = "Chroma"
Expand Down Expand Up @@ -77,6 +78,10 @@ def init_cls(self) -> Type[VectorDB]:
from .pgvectorscale.pgvectorscale import PgVectorScale
return PgVectorScale

if self == DB.PgDiskANN:
from .pgdiskann.pgdiskann import PgDiskANN
return PgDiskANN

if self == DB.Redis:
from .redis.redis import Redis
return Redis
Expand Down Expand Up @@ -132,6 +137,10 @@ def config_cls(self) -> Type[DBConfig]:
from .pgvectorscale.config import PgVectorScaleConfig
return PgVectorScaleConfig

if self == DB.PgDiskANN:
from .pgdiskann.config import PgDiskANNConfig
return PgDiskANNConfig

if self == DB.Redis:
from .redis.config import RedisConfig
return RedisConfig
Expand Down Expand Up @@ -185,6 +194,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
from .pgvectorscale.config import _pgvectorscale_case_config
return _pgvectorscale_case_config.get(index_type)

if self == DB.PgDiskANN:
from .pgdiskann.config import _pgdiskann_case_config
return _pgdiskann_case_config.get(index_type)

# DB.Pinecone, DB.Chroma, DB.Redis
return EmptyDBCaseConfig

Expand Down
145 changes: 145 additions & 0 deletions vectordb_bench/backend/clients/pgdiskann/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from abc import abstractmethod
from typing import Any, Mapping, Optional, Sequence, TypedDict
from pydantic import BaseModel, SecretStr
from typing_extensions import LiteralString
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType

# %-style template for a PostgreSQL URL with four string slots.
# NOTE(review): presumably filled as (user, password, host, dbname); it is not
# referenced anywhere in the visible part of this file - confirm it is used.
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"


class PgDiskANNConfigDict(TypedDict):
    """Typed shape of the connection parameters for a PgDiskANN database.

    The keys are used directly as keyword arguments when connecting through
    psycopg, so their names must match the psycopg API exactly.
    """

    user: str
    password: str
    host: str
    port: int
    dbname: str


class PgDiskANNConfig(DBConfig):
    """Connection settings for the PgDiskANN client."""

    # Credentials are held as SecretStr; to_dict() unwraps them.
    user_name: SecretStr = SecretStr("postgres")
    password: SecretStr
    host: str = "localhost"
    port: int = 5432
    db_name: str

    def to_dict(self) -> PgDiskANNConfigDict:
        """Return the settings as a plain dict keyed by psycopg parameter names.

        The secret fields are unwrapped, so the returned dict contains the
        user name and password in clear text.
        """
        return {
            "host": self.host,
            "port": self.port,
            "dbname": self.db_name,
            "user": self.user_name.get_secret_value(),
            "password": self.password.get_secret_value(),
        }


class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
    """Base case configuration shared by PgDiskANN index configurations.

    Maps the configured metric to the matching operator class, distance
    operator and distance function name, declares the parameter hooks that
    concrete configurations implement, and offers helpers that turn option
    mappings into lists of name/value records.
    """

    metric_type: MetricType | None = None
    create_index_before_load: bool = False
    create_index_after_load: bool = True
    maintenance_work_mem: Optional[str]
    max_parallel_workers: Optional[int]

    def parse_metric(self) -> str:
        """Return the operator class name for the metric (cosine unless L2 or IP)."""
        if self.metric_type == MetricType.L2:
            return "vector_l2_ops"
        if self.metric_type == MetricType.IP:
            return "vector_ip_ops"
        return "vector_cosine_ops"

    def parse_metric_fun_op(self) -> LiteralString:
        """Return the distance operator for the metric (cosine unless L2 or IP)."""
        if self.metric_type == MetricType.L2:
            return "<->"
        if self.metric_type == MetricType.IP:
            return "<#>"
        return "<=>"

    def parse_metric_fun_str(self) -> str:
        """Return the distance function name for the metric (cosine unless L2 or IP)."""
        if self.metric_type == MetricType.L2:
            return "l2_distance"
        if self.metric_type == MetricType.IP:
            return "max_inner_product"
        return "cosine_distance"

    @abstractmethod
    def index_param(self) -> dict:
        """Return the parameters used when creating the index."""
        ...

    @abstractmethod
    def search_param(self) -> dict:
        """Return the parameters used when searching."""
        ...

    @abstractmethod
    def session_param(self) -> dict:
        """Return the session-level settings for this case."""
        ...

    @staticmethod
    def _optionally_build_with_options(
        with_options: Mapping[str, Any]
    ) -> Sequence[dict[str, Any]]:
        """Convert a mapping of index options into a list of records.

        Each entry whose value is not None becomes
        ``{"option_name": name, "val": str(value)}``; the list is intended
        for building the WITH clause of an index definition. Entries whose
        value is None are left out.
        """
        return [
            {"option_name": option_name, "val": str(value)}
            for option_name, value in with_options.items()
            if value is not None
        ]

    @staticmethod
    def _optionally_build_set_options(
        set_mapping: Mapping[str, Any]
    ) -> Sequence[dict[str, Any]]:
        """Convert a mapping of settings into a list of records for SET statements.

        Each kept entry becomes
        ``{"parameter": {"setting_name": name, "val": str(value)}}``.
        Entries whose value is None or an empty string are left out.
        """
        session_options = []
        for setting_name, value in set_mapping.items():
            # Only skip values that are genuinely unset. A plain truthiness
            # test would also discard meaningful settings such as 0 or False
            # (e.g. max_parallel_workers = 0).
            if value is None or value == "":
                continue
            session_options.append(
                {
                    "parameter": {
                        "setting_name": setting_name,
                        "val": str(value),
                    },
                }
            )
        return session_options


class PgDiskANNImplConfig(PgDiskANNIndexConfig):
    """Case configuration for the DISKANN index type."""

    index: IndexType = IndexType.DISKANN
    # NOTE(review): max_neighbors and l_value_ib are passed as index options,
    # l_value_is as the "diskann.l_value_is" session setting; their exact
    # meaning is defined by the extension - confirm against its documentation.
    max_neighbors: int | None
    l_value_ib: int | None
    l_value_is: float | None
    maintenance_work_mem: Optional[str] = None
    max_parallel_workers: Optional[int] = None

    def index_param(self) -> dict:
        """Return the index-creation parameters.

        The result holds the metric operator class, the index type value, the
        index options (``max_neighbors``, ``l_value_ib``) and the
        ``maintenance_work_mem`` / ``max_parallel_workers`` values. Unset
        values are returned as None rather than omitted.
        """
        index_options = {
            "max_neighbors": self.max_neighbors,
            "l_value_ib": self.l_value_ib,
        }
        return {
            "metric": self.parse_metric(),
            "index_type": self.index.value,
            "options": index_options,
            "maintenance_work_mem": self.maintenance_work_mem,
            "max_parallel_workers": self.max_parallel_workers,
        }

    def search_param(self) -> dict:
        """Return the metric operator class and distance operator used for search."""
        return {
            "metric": self.parse_metric(),
            "metric_fun_op": self.parse_metric_fun_op(),
        }

    def session_param(self) -> dict:
        """Return the session settings, keyed by setting name."""
        return {
            "diskann.l_value_is": self.l_value_is,
        }

# Lookup table from index type to the case-config class that handles it.
_pgdiskann_case_config = {IndexType.DISKANN: PgDiskANNImplConfig}
Loading

0 comments on commit 3da8a9d

Please sign in to comment.