Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for pgdiskann client #388

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ All the database client supported
| pgvector | `pip install vectordb-bench[pgvector]` |
| pgvecto.rs | `pip install vectordb-bench[pgvecto_rs]` |
| pgvectorscale | `pip install vectordb-bench[pgvectorscale]` |
| pgdiskann | `pip install vectordb-bench[pgdiskann]` |
| redis | `pip install vectordb-bench[redis]` |
| memorydb | `pip install vectordb-bench[memorydb]` |
| chromadb | `pip install vectordb-bench[chromadb]` |
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ weaviate = [ "weaviate-client" ]
elastic = [ "elasticsearch" ]
pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvectorscale = [ "psycopg", "psycopg-binary", "pgvector" ]
pgdiskann = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "pgvecto_rs[psycopg3]>=0.2.2" ]
redis = [ "redis" ]
memorydb = [ "memorydb" ]
Expand Down
13 changes: 13 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class DB(Enum):
PgVector = "PgVector"
PgVectoRS = "PgVectoRS"
PgVectorScale = "PgVectorScale"
PgDiskANN = "PgDiskANN"
Redis = "Redis"
MemoryDB = "MemoryDB"
Chroma = "Chroma"
Expand Down Expand Up @@ -77,6 +78,10 @@ def init_cls(self) -> Type[VectorDB]:
from .pgvectorscale.pgvectorscale import PgVectorScale
return PgVectorScale

if self == DB.PgDiskANN:
from .pgdiskann.pgdiskann import PgDiskANN
return PgDiskANN

if self == DB.Redis:
from .redis.redis import Redis
return Redis
Expand Down Expand Up @@ -132,6 +137,10 @@ def config_cls(self) -> Type[DBConfig]:
from .pgvectorscale.config import PgVectorScaleConfig
return PgVectorScaleConfig

if self == DB.PgDiskANN:
from .pgdiskann.config import PgDiskANNConfig
return PgDiskANNConfig

if self == DB.Redis:
from .redis.config import RedisConfig
return RedisConfig
Expand Down Expand Up @@ -185,6 +194,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
from .pgvectorscale.config import _pgvectorscale_case_config
return _pgvectorscale_case_config.get(index_type)

if self == DB.PgDiskANN:
from .pgdiskann.config import _pgdiskann_case_config
return _pgdiskann_case_config.get(index_type)

# DB.Pinecone, DB.Chroma, DB.Redis
return EmptyDBCaseConfig

Expand Down
99 changes: 99 additions & 0 deletions vectordb_bench/backend/clients/pgdiskann/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import click
import os
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from typing import Annotated, Optional, Unpack
from vectordb_bench.backend.clients import DB


class PgDiskAnnTypedDict(CommonTypedDict):
user_name: Annotated[
str, click.option("--user-name", type=str, help="Db username", required=True)
]
password: Annotated[
str,
click.option("--password",
type=str,
help="Postgres database password",
default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
show_default="$POSTGRES_PASSWORD",
),
]

host: Annotated[
str, click.option("--host", type=str, help="Db host", required=True)
]
db_name: Annotated[
str, click.option("--db-name", type=str, help="Db name", required=True)
]
max_neighbors: Annotated[
int,
click.option(
"--max-neighbors", type=int, help="PgDiskAnn max neighbors",
),
]
l_value_ib: Annotated[
int,
click.option(
"--l-value-ib", type=int, help="PgDiskAnn l_value_ib",
),
]
l_value_is: Annotated[
float,
click.option(
"--l-value-is", type=float, help="PgDiskAnn l_value_is",
),
]
maintenance_work_mem: Annotated[
Optional[str],
click.option(
"--maintenance-work-mem",
type=str,
help="Sets the maximum memory to be used for maintenance operations (index creation). "
"Can be entered as string with unit like '64GB' or as an integer number of KB."
"This will set the parameters: max_parallel_maintenance_workers,"
" max_parallel_workers & table(parallel_workers)",
Comment on lines +60 to +61
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These help text lines (60-61) belong with "--max-parallel-workers", this is also an issue an issue with the pgvectory cli.py.

required=False,
),
]
max_parallel_workers: Annotated[
Optional[int],
click.option(
"--max-parallel-workers",
type=int,
help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
required=False,
),
]

@cli.command()
@click_parameter_decorators_from_typed_dict(PgDiskAnnTypedDict)
def PgDiskAnn(
**parameters: Unpack[PgDiskAnnTypedDict],
):
from .config import PgDiskANNConfig, PgDiskANNImplConfig

run(
db=DB.PgDiskANN,
db_config=PgDiskANNConfig(
db_label=parameters["db_label"],
user_name=SecretStr(parameters["user_name"]),
password=SecretStr(parameters["password"]),
host=parameters["host"],
db_name=parameters["db_name"],
),
db_case_config=PgDiskANNImplConfig(
max_neighbors=parameters["max_neighbors"],
l_value_ib=parameters["l_value_ib"],
l_value_is=parameters["l_value_is"],
max_parallel_workers=parameters["max_parallel_workers"],
maintenance_work_mem=parameters["maintenance_work_mem"],
),
**parameters,
)
145 changes: 145 additions & 0 deletions vectordb_bench/backend/clients/pgdiskann/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from abc import abstractmethod
from typing import Any, Mapping, Optional, Sequence, TypedDict
from pydantic import BaseModel, SecretStr
from typing_extensions import LiteralString
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType

POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"


class PgDiskANNConfigDict(TypedDict):
"""These keys will be directly used as kwargs in psycopg connection string,
so the names must match exactly psycopg API"""

user: str
password: str
host: str
port: int
dbname: str


class PgDiskANNConfig(DBConfig):
user_name: SecretStr = SecretStr("postgres")
password: SecretStr
host: str = "localhost"
port: int = 5432
db_name: str

def to_dict(self) -> PgDiskANNConfigDict:
user_str = self.user_name.get_secret_value()
pwd_str = self.password.get_secret_value()
return {
"host": self.host,
"port": self.port,
"dbname": self.db_name,
"user": user_str,
"password": pwd_str,
}


class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
create_index_before_load: bool = False
create_index_after_load: bool = True
maintenance_work_mem: Optional[str]
max_parallel_workers: Optional[int]

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
elif self.metric_type == MetricType.IP:
return "vector_ip_ops"
return "vector_cosine_ops"

def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.L2:
return "<->"
elif self.metric_type == MetricType.IP:
return "<#>"
return "<=>"

def parse_metric_fun_str(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_distance"
elif self.metric_type == MetricType.IP:
return "max_inner_product"
return "cosine_distance"

@abstractmethod
def index_param(self) -> dict:
...

@abstractmethod
def search_param(self) -> dict:
...

@abstractmethod
def session_param(self) -> dict:
...

@staticmethod
def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
"""Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
options = []
for option_name, value in with_options.items():
if value is not None:
options.append(
{
"option_name": option_name,
"val": str(value),
}
)
return options

@staticmethod
def _optionally_build_set_options(
set_mapping: Mapping[str, Any]
) -> Sequence[dict[str, Any]]:
"""Walk through options, creating 'SET 'key1 = "value1";' list"""
session_options = []
for setting_name, value in set_mapping.items():
if value:
session_options.append(
{"parameter": {
"setting_name": setting_name,
"val": str(value),
},
}
)
return session_options


class PgDiskANNImplConfig(PgDiskANNIndexConfig):
index: IndexType = IndexType.DISKANN
max_neighbors: int | None
l_value_ib: int | None
l_value_is: float | None
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None

def index_param(self) -> dict:
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
"options": {
"max_neighbors": self.max_neighbors,
"l_value_ib": self.l_value_ib,
},
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
}

def search_param(self) -> dict:
return {
"metric": self.parse_metric(),
"metric_fun_op": self.parse_metric_fun_op(),
}

def session_param(self) -> dict:
return {
"diskann.l_value_is": self.l_value_is,
}

_pgdiskann_case_config = {
IndexType.DISKANN: PgDiskANNImplConfig,
}
Loading