Skip to content

Commit

Permalink
refactor: migrate to new pgvecto_rs sdk
Browse files Browse the repository at this point in the history
Signed-off-by: cutecutecat <[email protected]>
  • Loading branch information
cutecutecat committed Jul 30, 2024
1 parent c45876c commit 4853d82
Show file tree
Hide file tree
Showing 8 changed files with 480 additions and 147 deletions.
1 change: 1 addition & 0 deletions install/requirements_py3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pinecone-client
weaviate-client
elasticsearch
pgvector
pgvecto_rs[psycopg3]>=0.2.1
sqlalchemy
redis
chromadb
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ all = [
"weaviate-client",
"elasticsearch",
"pgvector",
"pgvecto_rs[psycopg3]>=0.2.1",
"sqlalchemy",
"redis",
"chromadb",
"psycopg2",
"psycopg",
"psycopg-binary",
"opensearch-dsl==2.1.0",
Expand All @@ -71,7 +71,7 @@ pinecone = [ "pinecone-client" ]
weaviate = [ "weaviate-client" ]
elastic = [ "elasticsearch" ]
pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "psycopg2" ]
pgvecto_rs = [ "pgvecto_rs[psycopg3]>=0.2.1" ]
redis = [ "redis" ]
chromadb = [ "chromadb" ]
awsopensearch = [ "awsopensearch" ]
Expand Down
154 changes: 154 additions & 0 deletions vectordb_bench/backend/clients/pgvecto_rs/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from typing import Annotated, Optional, Unpack

import click
import os
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor1,
IVFFlatTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from vectordb_bench.backend.clients import DB


class PgVectoRSTypedDict(CommonTypedDict):
user_name: Annotated[
str, click.option("--user-name", type=str, help="Db username", required=True)
]
password: Annotated[
str,
click.option(
"--password",
type=str,
help="Postgres database password",
default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
show_default="$POSTGRES_PASSWORD",
),
]

host: Annotated[
str, click.option("--host", type=str, help="Db host", required=True)
]
db_name: Annotated[
str, click.option("--db-name", type=str, help="Db name", required=True)
]
max_parallel_workers: Annotated[
Optional[int],
click.option(
"--max-parallel-workers",
type=int,
help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
required=False,
),
]
quantization_type: Annotated[
str,
click.option(
"--quantization-type",
type=click.Choice(["trivial", "scalar", "product"]),
help="quantization type for vectors",
required=False,
),
]
quantization_ratio: Annotated[
str,
click.option(
"--quantization-ratio",
type=click.Choice(["x4", "x8", "x16", "x32", "x64"]),
help="quantization ratio(for product quantization)",
required=False,
),
]


class PgVectoRSFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict): ...


@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSFlatTypedDict)
def PgVectoRSFlat(
**parameters: Unpack[PgVectoRSFlatTypedDict],
):
from .config import PgVectoRSConfig, PgVectoRSFLATConfig

run(
db=DB.PgVectoRS,
db_config=PgVectoRSConfig(
db_label=parameters["db_label"],
user_name=SecretStr(parameters["user_name"]),
password=SecretStr(parameters["password"]),
host=parameters["host"],
db_name=parameters["db_name"],
),
db_case_config=PgVectoRSFLATConfig(
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
quantization_ratio=parameters["quantization_ratio"],
),
**parameters,
)


class PgVectoRSIVFFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict): ...


@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSIVFFlatTypedDict)
def PgVectoRSIVFFlat(
**parameters: Unpack[PgVectoRSIVFFlatTypedDict],
):
from .config import PgVectoRSConfig, PgVectoRSIVFFlatConfig

run(
db=DB.PgVectoRS,
db_config=PgVectoRSConfig(
db_label=parameters["db_label"],
user_name=SecretStr(parameters["user_name"]),
password=SecretStr(parameters["password"]),
host=parameters["host"],
db_name=parameters["db_name"],
),
db_case_config=PgVectoRSIVFFlatConfig(
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
quantization_ratio=parameters["quantization_ratio"],
probes=parameters["probes"],
lists=parameters["lists"],
),
**parameters,
)


class PgVectoRSHNSWTypedDict(PgVectoRSTypedDict, HNSWFlavor1): ...


@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSHNSWTypedDict)
def PgVectoRSHNSW(
**parameters: Unpack[PgVectoRSHNSWTypedDict],
):
from .config import PgVectoRSConfig, PgVectoRSHNSWConfig

run(
db=DB.PgVectoRS,
db_config=PgVectoRSConfig(
db_label=parameters["db_label"],
user_name=SecretStr(parameters["user_name"]),
password=SecretStr(parameters["password"]),
host=parameters["host"],
db_name=parameters["db_name"],
),
db_case_config=PgVectoRSHNSWConfig(
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
quantization_ratio=parameters["quantization_ratio"],
m=parameters["m"],
ef_construction=parameters["ef_construction"],
ef_search=parameters["ef_search"],
),
**parameters,
)
Loading

0 comments on commit 4853d82

Please sign in to comment.