Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: migrate to new pgvecto_rs sdk #353

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions install/requirements_py3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pinecone-client
weaviate-client
elasticsearch
pgvector
pgvecto_rs[psycopg3]>=0.2.1
sqlalchemy
redis
chromadb
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ all = [
"weaviate-client",
"elasticsearch",
"pgvector",
"pgvecto_rs[psycopg3]>=0.2.1",
"sqlalchemy",
"redis",
"chromadb",
"psycopg2",
"psycopg",
"psycopg-binary",
"opensearch-dsl==2.1.0",
Expand All @@ -71,7 +71,7 @@ pinecone = [ "pinecone-client" ]
weaviate = [ "weaviate-client" ]
elastic = [ "elasticsearch" ]
pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "psycopg2" ]
pgvecto_rs = [ "pgvecto_rs[psycopg3]>=0.2.1" ]
redis = [ "redis" ]
chromadb = [ "chromadb" ]
awsopensearch = [ "awsopensearch" ]
Expand Down
154 changes: 154 additions & 0 deletions vectordb_bench/backend/clients/pgvecto_rs/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from typing import Annotated, Optional, Unpack

import click
import os
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor1,
IVFFlatTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from vectordb_bench.backend.clients import DB


class PgVectoRSTypedDict(CommonTypedDict):
    """CLI options shared by every pgvecto.rs command in this module.

    Each field is an ``Annotated`` pair: the Python type of the parsed value
    and the ``click.option`` that declares the corresponding command-line
    flag. Options declared with ``required=False`` and no default are typed
    ``Optional[...]`` because the parsed value is ``None`` when the flag is
    omitted.
    """

    # --- connection settings -------------------------------------------
    user_name: Annotated[
        str, click.option("--user-name", type=str, help="Db username", required=True)
    ]
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="Postgres database password",
            # Callable default: falls back to $POSTGRES_PASSWORD (or "" when
            # the variable is unset) if --password is not given.
            default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
            # Show the variable name in --help instead of the secret value.
            show_default="$POSTGRES_PASSWORD",
        ),
    ]

    host: Annotated[
        str, click.option("--host", type=str, help="Db host", required=True)
    ]
    db_name: Annotated[
        str, click.option("--db-name", type=str, help="Db name", required=True)
    ]

    # --- index build / quantization settings (all optional) -------------
    max_parallel_workers: Annotated[
        Optional[int],
        click.option(
            "--max-parallel-workers",
            type=int,
            help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
            required=False,
        ),
    ]
    # Optional[str]: not required and no default, so None when omitted
    # (kept consistent with max_parallel_workers above).
    quantization_type: Annotated[
        Optional[str],
        click.option(
            "--quantization-type",
            type=click.Choice(["trivial", "scalar", "product"]),
            help="quantization type for vectors",
            required=False,
        ),
    ]
    # Optional[str]: same reasoning as quantization_type.
    quantization_ratio: Annotated[
        Optional[str],
        click.option(
            "--quantization-ratio",
            type=click.Choice(["x4", "x8", "x16", "x32", "x64"]),
            help="quantization ratio(for product quantization)",
            required=False,
        ),
    ]


# NOTE(review): the Flat command below never reads "lists"/"probes";
# presumably those options come from IVFFlatTypedDict (the IVFFlat command
# reads them) -- confirm whether a narrower base was intended here.
class PgVectoRSFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict):
    """Option set for the ``PgVectoRSFlat`` command.

    Combines the shared pgvecto.rs options with ``IVFFlatTypedDict``; adds no
    fields of its own.
    """


@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSFlatTypedDict)
def PgVectoRSFlat(
    **parameters: Unpack[PgVectoRSFlatTypedDict],
):
    # CLI command: build the pgvecto.rs connection config and the FLAT case
    # config from the parsed options, then hand everything to run().
    # Deliberately documented with comments rather than a docstring: click
    # would surface a docstring as this command's help text.
    from .config import PgVectoRSConfig, PgVectoRSFLATConfig

    target_db = DB.PgVectoRS

    # Credentials are wrapped in SecretStr before being passed on.
    connection = PgVectoRSConfig(
        db_label=parameters["db_label"],
        user_name=SecretStr(parameters["user_name"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        db_name=parameters["db_name"],
    )

    # Options forwarded verbatim to the case config.
    case_keys = (
        "max_parallel_workers",
        "quantization_type",
        "quantization_ratio",
    )
    index_settings = PgVectoRSFLATConfig(
        **{key: parameters[key] for key in case_keys}
    )

    # The full option dict is also forwarded to run() as keyword arguments.
    run(
        db=target_db,
        db_config=connection,
        db_case_config=index_settings,
        **parameters,
    )


class PgVectoRSIVFFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict):
    """Option set for the ``PgVectoRSIVFFlat`` command.

    Combines the shared pgvecto.rs options with ``IVFFlatTypedDict``; adds no
    fields of its own.
    """


@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSIVFFlatTypedDict)
def PgVectoRSIVFFlat(
    **parameters: Unpack[PgVectoRSIVFFlatTypedDict],
):
    # CLI command: build the pgvecto.rs connection config and the IVFFlat
    # case config from the parsed options, then hand everything to run().
    # Deliberately documented with comments rather than a docstring: click
    # would surface a docstring as this command's help text.
    from .config import PgVectoRSConfig, PgVectoRSIVFFlatConfig

    target_db = DB.PgVectoRS

    # Credentials are wrapped in SecretStr before being passed on.
    connection = PgVectoRSConfig(
        db_label=parameters["db_label"],
        user_name=SecretStr(parameters["user_name"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        db_name=parameters["db_name"],
    )

    # Options forwarded verbatim to the case config.
    case_keys = (
        "max_parallel_workers",
        "quantization_type",
        "quantization_ratio",
        "probes",
        "lists",
    )
    index_settings = PgVectoRSIVFFlatConfig(
        **{key: parameters[key] for key in case_keys}
    )

    # The full option dict is also forwarded to run() as keyword arguments.
    run(
        db=target_db,
        db_config=connection,
        db_case_config=index_settings,
        **parameters,
    )


class PgVectoRSHNSWTypedDict(PgVectoRSTypedDict, HNSWFlavor1):
    """Option set for the ``PgVectoRSHNSW`` command.

    Combines the shared pgvecto.rs options with ``HNSWFlavor1``; adds no
    fields of its own.
    """


@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSHNSWTypedDict)
def PgVectoRSHNSW(
    **parameters: Unpack[PgVectoRSHNSWTypedDict],
):
    # CLI command: build the pgvecto.rs connection config and the HNSW case
    # config from the parsed options, then hand everything to run().
    # Deliberately documented with comments rather than a docstring: click
    # would surface a docstring as this command's help text.
    from .config import PgVectoRSConfig, PgVectoRSHNSWConfig

    target_db = DB.PgVectoRS

    # Credentials are wrapped in SecretStr before being passed on.
    connection = PgVectoRSConfig(
        db_label=parameters["db_label"],
        user_name=SecretStr(parameters["user_name"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        db_name=parameters["db_name"],
    )

    # Options forwarded verbatim to the case config.
    case_keys = (
        "max_parallel_workers",
        "quantization_type",
        "quantization_ratio",
        "m",
        "ef_construction",
        "ef_search",
    )
    index_settings = PgVectoRSHNSWConfig(
        **{key: parameters[key] for key in case_keys}
    )

    # The full option dict is also forwarded to run() as keyword arguments.
    run(
        db=target_db,
        db_config=connection,
        db_case_config=index_settings,
        **parameters,
    )
Loading