From 4853d82024649cb8ce36024c5b9521b2cdad37d1 Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 29 Jul 2024 17:07:42 +0800 Subject: [PATCH] refactor: migrate to new pgvecto_rs sdk Signed-off-by: cutecutecat --- install/requirements_py3.11.txt | 1 + pyproject.toml | 4 +- .../backend/clients/pgvecto_rs/cli.py | 154 +++++++++++++ .../backend/clients/pgvecto_rs/config.py | 181 +++++++++------ .../backend/clients/pgvecto_rs/pgvecto_rs.py | 218 +++++++++++++----- vectordb_bench/cli/vectordbbench.py | 3 + .../frontend/config/dbCaseConfigs.py | 59 ++++- vectordb_bench/models.py | 7 +- 8 files changed, 480 insertions(+), 147 deletions(-) create mode 100644 vectordb_bench/backend/clients/pgvecto_rs/cli.py diff --git a/install/requirements_py3.11.txt b/install/requirements_py3.11.txt index c55601934..c3a3bbbda 100644 --- a/install/requirements_py3.11.txt +++ b/install/requirements_py3.11.txt @@ -5,6 +5,7 @@ pinecone-client weaviate-client elasticsearch pgvector +pgvecto_rs[psycopg3]>=0.2.1 sqlalchemy redis chromadb diff --git a/pyproject.toml b/pyproject.toml index 2a812179c..df9d2ee0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,10 +56,10 @@ all = [ "weaviate-client", "elasticsearch", "pgvector", + "pgvecto_rs[psycopg3]>=0.2.1", "sqlalchemy", "redis", "chromadb", - "psycopg2", "psycopg", "psycopg-binary", "opensearch-dsl==2.1.0", @@ -71,7 +71,7 @@ pinecone = [ "pinecone-client" ] weaviate = [ "weaviate-client" ] elastic = [ "elasticsearch" ] pgvector = [ "psycopg", "psycopg-binary", "pgvector" ] -pgvecto_rs = [ "psycopg2" ] +pgvecto_rs = [ "pgvecto_rs[psycopg3]>=0.2.1" ] redis = [ "redis" ] chromadb = [ "chromadb" ] awsopensearch = [ "awsopensearch" ] diff --git a/vectordb_bench/backend/clients/pgvecto_rs/cli.py b/vectordb_bench/backend/clients/pgvecto_rs/cli.py new file mode 100644 index 000000000..10dbff556 --- /dev/null +++ b/vectordb_bench/backend/clients/pgvecto_rs/cli.py @@ -0,0 +1,154 @@ +from typing import Annotated, Optional, Unpack + +import click +import os +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + HNSWFlavor1, + IVFFlatTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from vectordb_bench.backend.clients import DB + + +class PgVectoRSTypedDict(CommonTypedDict): + user_name: Annotated[ + str, click.option("--user-name", type=str, help="Db username", required=True) + ] + password: Annotated[ + str, + click.option( + "--password", + type=str, + help="Postgres database password", + default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), + show_default="$POSTGRES_PASSWORD", + ), + ] + + host: Annotated[ + str, click.option("--host", type=str, help="Db host", required=True) + ] + db_name: Annotated[ + str, click.option("--db-name", type=str, help="Db name", required=True) + ] + max_parallel_workers: Annotated[ + Optional[int], + click.option( + "--max-parallel-workers", + type=int, + help="Sets the maximum number of parallel processes per maintenance operation (index creation)", + required=False, + ), + ] + quantization_type: Annotated[ + str, + click.option( + "--quantization-type", + type=click.Choice(["trivial", "scalar", "product"]), + help="quantization type for vectors", + required=False, + ), + ] + quantization_ratio: Annotated[ + str, + click.option( + "--quantization-ratio", + type=click.Choice(["x4", "x8", "x16", "x32", "x64"]), + help="quantization ratio(for product quantization)", + required=False, + ), + ] + + +class PgVectoRSFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict): ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(PgVectoRSFlatTypedDict) +def PgVectoRSFlat( + **parameters: Unpack[PgVectoRSFlatTypedDict], +): + from .config import PgVectoRSConfig, PgVectoRSFLATConfig + + run( + db=DB.PgVectoRS, + db_config=PgVectoRSConfig( + db_label=parameters["db_label"], + user_name=SecretStr(parameters["user_name"]), + password=SecretStr(parameters["password"]), + host=parameters["host"], + db_name=parameters["db_name"], + ), + db_case_config=PgVectoRSFLATConfig( + max_parallel_workers=parameters["max_parallel_workers"], + quantization_type=parameters["quantization_type"], + quantization_ratio=parameters["quantization_ratio"], + ), + **parameters, + ) + + +class PgVectoRSIVFFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict): ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(PgVectoRSIVFFlatTypedDict) +def PgVectoRSIVFFlat( + **parameters: Unpack[PgVectoRSIVFFlatTypedDict], +): + from .config import PgVectoRSConfig, PgVectoRSIVFFlatConfig + + run( + db=DB.PgVectoRS, + db_config=PgVectoRSConfig( + db_label=parameters["db_label"], + user_name=SecretStr(parameters["user_name"]), + password=SecretStr(parameters["password"]), + host=parameters["host"], + db_name=parameters["db_name"], + ), + db_case_config=PgVectoRSIVFFlatConfig( + max_parallel_workers=parameters["max_parallel_workers"], + quantization_type=parameters["quantization_type"], + quantization_ratio=parameters["quantization_ratio"], + probes=parameters["probes"], + lists=parameters["lists"], + ), + **parameters, + ) + + +class PgVectoRSHNSWTypedDict(PgVectoRSTypedDict, HNSWFlavor1): ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(PgVectoRSHNSWTypedDict) +def PgVectoRSHNSW( + **parameters: Unpack[PgVectoRSHNSWTypedDict], +): + from .config import PgVectoRSConfig, PgVectoRSHNSWConfig + + run( + db=DB.PgVectoRS, + db_config=PgVectoRSConfig( + db_label=parameters["db_label"], + user_name=SecretStr(parameters["user_name"]), + password=SecretStr(parameters["password"]), + host=parameters["host"], + db_name=parameters["db_name"], + ), + db_case_config=PgVectoRSHNSWConfig( + max_parallel_workers=parameters["max_parallel_workers"], + quantization_type=parameters["quantization_type"], + quantization_ratio=parameters["quantization_ratio"], + m=parameters["m"], + ef_construction=parameters["ef_construction"], + ef_search=parameters["ef_search"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/pgvecto_rs/config.py b/vectordb_bench/backend/clients/pgvecto_rs/config.py index 73c41c239..c671a236c 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/config.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/config.py @@ -1,30 +1,53 @@ -from typing import Literal +from abc import abstractmethod +from typing import TypedDict + from pydantic import BaseModel, SecretStr -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType +from pgvecto_rs.types import IndexOption, Ivf, Hnsw, Flat, Quantization +from pgvecto_rs.types.index import QuantizationType, QuantizationRatio + +from ..api import DBConfig, DBCaseConfig, IndexType, MetricType POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" +class PgVectorRSConfigDict(TypedDict): + """These keys will be directly used as kwargs in psycopg connection string, + so the names must match exactly psycopg API""" + + user: str + password: str + host: str + port: int + dbname: str + + class PgVectoRSConfig(DBConfig): - user_name: SecretStr = "postgres" + user_name: str = "postgres" password: SecretStr host: str = "localhost" port: int = 5432 db_name: str def to_dict(self) -> dict: - user_str = self.user_name.get_secret_value() + user_str = self.user_name pwd_str = self.password.get_secret_value() return { "host": self.host, "port": self.port, "dbname": self.db_name, "user": user_str, - "password": pwd_str + "password": pwd_str, } + class PgVectoRSIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None + create_index_before_load: bool = False + create_index_after_load: bool = True + + max_parallel_workers: int | None = None + quantization_type: QuantizationType | None = None + quantization_ratio: QuantizationRatio | None = None def parse_metric(self) -> str: if self.metric_type == MetricType.L2: @@ -40,88 +63,100 @@ def parse_metric_fun_op(self) -> str: return "<#>" return "<=>" -class PgVectoRSQuantConfig(PgVectoRSIndexConfig): - quantizationType: Literal["trivial", "scalar", "product"] - quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"] - - def parse_quantization(self) -> str: - if self.quantizationType == "trivial": - return "quantization = { trivial = { } }" - elif self.quantizationType == "scalar": - return "quantization = { scalar = { } }" - else: - return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}' - + def search_param(self) -> dict: + return { + "metric_fun_op": self.parse_metric_fun_op(), + } -class HNSWConfig(PgVectoRSQuantConfig): - M: int - efConstruction: int - index: IndexType = IndexType.HNSW + @abstractmethod + def index_param(self) -> dict[str, str]: ... - def index_param(self) -> dict: - options = f""" -[indexing.hnsw] -m = {self.M} -ef_construction = {self.efConstruction} -{self.parse_quantization()} -""" - return {"options": options, "metric": self.parse_metric()} + @abstractmethod + def session_param(self) -> dict[str, str | int]: ... - def search_param(self) -> dict: - return {"metrics_op": self.parse_metric_fun_op()} +class PgVectoRSHNSWConfig(PgVectoRSIndexConfig): + index: IndexType = IndexType.HNSW + m: int | None = None + ef_search: int | None + ef_construction: int | None = None -class IVFFlatConfig(PgVectoRSQuantConfig): - nlist: int - nprobe: int | None = None + def index_param(self) -> dict[str, str]: + if self.quantization_type is None: + quantization = None + else: + quantization = Quantization( + typ=self.quantization_type, ratio=self.quantization_ratio + ) + + option = IndexOption( + index=Hnsw( + m=self.m, + ef_construction=self.ef_construction, + quantization=quantization, + ), + threads=self.max_parallel_workers, + ) + return {"options": option.dumps(), "metric": self.parse_metric()} + + def session_param(self) -> dict[str, str | int]: + session_parameters = {} + if self.ef_search is not None: + session_parameters["vectors.hnsw_ef_search"] = str(self.ef_search) + return session_parameters + + +class PgVectoRSIVFFlatConfig(PgVectoRSIndexConfig): index: IndexType = IndexType.IVFFlat + probes: int | None + lists: int | None + + def index_param(self) -> dict[str, str]: + if self.quantization_type is None: + quantization = None + else: + quantization = Quantization( + typ=self.quantization_type, ratio=self.quantization_ratio + ) - def index_param(self) -> dict: - options = f""" -[indexing.ivf] -nlist = {self.nlist} -nsample = {self.nprobe if self.nprobe else 10} -{self.parse_quantization()} -""" - return {"options": options, "metric": self.parse_metric()} + option = IndexOption( + index=Ivf(nlist=self.lists, quantization=quantization), + threads=self.max_parallel_workers, + ) + return {"options": option.dumps(), "metric": self.parse_metric()} - def search_param(self) -> dict: - return {"metrics_op": self.parse_metric_fun_op()} - -class IVFFlatSQ8Config(PgVectoRSIndexConfig): - nlist: int - nprobe: int | None = None - index: IndexType = IndexType.IVFSQ8 - - def index_param(self) -> dict: - options = f""" -[indexing.ivf] -nlist = {self.nlist} -nsample = {self.nprobe if self.nprobe else 10} -quantization = {{ scalar = {{ }} }} -""" - return {"options": options, "metric": self.parse_metric()} + def session_param(self) -> dict[str, str | int]: + session_parameters = {} + if self.probes is not None: + session_parameters["vectors.ivf_nprobe"] = str(self.probes) + return session_parameters - def search_param(self) -> dict: - return {"metrics_op": self.parse_metric_fun_op()} -class FLATConfig(PgVectoRSQuantConfig): +class PgVectoRSFLATConfig(PgVectoRSIndexConfig): index: IndexType = IndexType.Flat - def index_param(self) -> dict: - options = f""" -[indexing.flat] -{self.parse_quantization()} -""" - return {"options": options, "metric": self.parse_metric()} + def index_param(self) -> dict[str, str]: + if self.quantization_type is None: + quantization = None + else: + quantization = Quantization( + typ=self.quantization_type, ratio=self.quantization_ratio + ) - def search_param(self) -> dict: - return {"metrics_op": self.parse_metric_fun_op()} + option = IndexOption( + index=Flat( + quantization=quantization, + ), + threads=self.max_parallel_workers, + ) + return {"options": option.dumps(), "metric": self.parse_metric()} + + def session_param(self) -> dict[str, str | int]: + return {} _pgvecto_rs_case_config = { - IndexType.HNSW: HNSWConfig, - IndexType.IVFFlat: IVFFlatConfig, - IndexType.IVFSQ8: IVFFlatSQ8Config, - IndexType.Flat: FLATConfig, + IndexType.HNSW: PgVectoRSHNSWConfig, + IndexType.IVFFlat: PgVectoRSIVFFlatConfig, + IndexType.Flat: PgVectoRSFLATConfig, } diff --git a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py index 22caa43e6..bc042cc57 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py @@ -1,73 +1,138 @@ """Wrapper around the Pgvecto.rs vector database over VectorDB""" -import io import logging +import pprint from contextlib import contextmanager -from typing import Any -import pandas as pd -import psycopg2 -import psycopg2.extras +from typing import Any, Generator, Optional, Tuple -from ..api import VectorDB, DBCaseConfig +import numpy as np +import psycopg +from psycopg import Connection, Cursor, sql +from pgvecto_rs.psycopg import register_vector + +from ..api import VectorDB +from .config import PgVectoRSConfig, PgVectoRSIndexConfig log = logging.getLogger(__name__) + class PgVectoRS(VectorDB): - """Use SQLAlchemy instructions""" + """Use psycopg instructions""" + + conn: psycopg.Connection[Any] | None = None + cursor: psycopg.Cursor[Any] | None = None + _unfiltered_search: sql.Composed + _filtered_search: sql.Composed def __init__( self, dim: int, - db_config: dict, - db_case_config: DBCaseConfig, - collection_name: str = "PgVectorCollection", + db_config: PgVectoRSConfig, + db_case_config: PgVectoRSIndexConfig, + collection_name: str = "PgVectoRSCollection", drop_old: bool = False, **kwargs, ): + + self.name = "PgVectorRS" self.db_config = db_config self.case_config = db_case_config self.table_name = collection_name self.dim = dim - self._index_name = "pqvector_index" + self._index_name = "pgvectors_index" self._primary_field = "id" self._vector_field = "embedding" # construct basic units - self.conn = psycopg2.connect(**self.db_config) - self.conn.autocommit = False - self.cursor = self.conn.cursor() + self.conn, self.cursor = self._create_connection(**self.db_config) - # create vector extension - self.cursor.execute("CREATE EXTENSION IF NOT EXISTS vectors") - self.conn.commit() + log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}") + if not any( + ( + self.case_config.create_index_before_load, + self.case_config.create_index_after_load, + ) + ): + err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load" + log.error(err) + raise RuntimeError( + f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" + ) if drop_old: log.info(f"Pgvecto.rs client drop table : {self.table_name}") self._drop_index() self._drop_table() self._create_table(dim) - self._create_index() + if self.case_config.create_index_before_load: + self._create_index() self.cursor.close() self.conn.close() self.cursor = None self.conn = None + @staticmethod + def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: + conn = psycopg.connect(**kwargs) + + # create vector extension + conn.execute("CREATE EXTENSION IF NOT EXISTS vectors") + conn.commit() + register_vector(conn) + + conn.autocommit = False + cursor = conn.cursor() + + assert conn is not None, "Connection is not initialized" + assert cursor is not None, "Cursor is not initialized" + + return conn, cursor + @contextmanager - def init(self) -> None: + def init(self) -> Generator[None, None, None]: """ Examples: >>> with self.init(): >>> self.insert_embeddings() >>> self.search_embedding() """ - self.conn = psycopg2.connect(**self.db_config) - self.conn.autocommit = False - self.cursor = self.conn.cursor() - self.cursor.execute('SET search_path = "$user", public, vectors') + + self.conn, self.cursor = self._create_connection(**self.db_config) + + # index configuration may have commands defined that we should set during each client session + session_options = self.case_config.session_param() + + for key, val in session_options.items(): + command = sql.SQL("SET {setting_name} " + "= {val};").format( + setting_name=sql.Identifier(key), + val=val, + ) + log.debug(command.as_string(self.cursor)) + self.cursor.execute(command) self.conn.commit() + self._filtered_search = sql.Composed( + [ + sql.SQL( + "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " + ).format(table_name=sql.Identifier(self.table_name)), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::vector LIMIT %s::int"), + ] + ) + + self._unfiltered_search = sql.Composed( + [ + sql.SQL( + "SELECT id FROM public.{table_name} ORDER BY embedding " + ).format(table_name=sql.Identifier(self.table_name)), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::vector LIMIT %s::int"), + ] + ) + try: yield finally: @@ -79,42 +144,65 @@ def init(self) -> None: def _drop_table(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client drop table : {self.table_name}") - self.cursor.execute(f'DROP TABLE IF EXISTS public."{self.table_name}"') + self.cursor.execute( + sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format( + table_name=sql.Identifier(self.table_name) + ) + ) self.conn.commit() def ready_to_load(self): pass def optimize(self): - pass + self._post_insert() - def ready_to_search(self): - pass + def _post_insert(self): + log.info(f"{self.name} post insert before optimize") + if self.case_config.create_index_after_load: + self._drop_index() + self._create_index() def _drop_index(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client drop index : {self._index_name}") - self.cursor.execute(f'DROP INDEX IF EXISTS "{self._index_name}"') + drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format( + index_name=sql.Identifier(self._index_name) + ) + log.debug(drop_index_sql.as_string(self.cursor)) + self.cursor.execute(drop_index_sql) self.conn.commit() def _create_index(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client create index : {self._index_name}") index_param = self.case_config.index_param() + index_create_sql = sql.SQL( + """ + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + USING vectors (embedding {embedding_metric}) WITH (options = {index_options}) + """ + ).format( + index_name=sql.Identifier(self._index_name), + table_name=sql.Identifier(self.table_name), + embedding_metric=sql.Identifier(index_param["metric"]), + index_options=index_param["options"], + ) try: - # create table - self.cursor.execute( - f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" \ - USING vectors (embedding {index_param["metric"]}) WITH (options = $${index_param["options"]}$$);' - ) + log.debug(index_create_sql.as_string(self.cursor)) + self.cursor.execute(index_create_sql) self.conn.commit() except Exception as e: log.warning( - f"Failed to create pgvecto.rs table: {self.table_name} error: {e}" + f"Failed to create pgvecto.rs index {self._index_name} \ + at table {self.table_name} error: {e}" ) raise e from None @@ -122,12 +210,18 @@ def _create_table(self, dim: int): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + table_create_sql = sql.SQL( + """ + CREATE TABLE IF NOT EXISTS public.{table_name} + (id BIGINT PRIMARY KEY, embedding vector({dim})) + """ + ).format( + table_name=sql.Identifier(self.table_name), + dim=dim, + ) try: # create table - self.cursor.execute( - f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" \ - (id Integer PRIMARY KEY, embedding vector({dim}));' - ) + self.cursor.execute(table_create_sql) self.conn.commit() except Exception as e: log.warning( @@ -140,7 +234,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> (int, Exception): + ) -> Tuple[int, Optional[Exception]]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -148,19 +242,27 @@ def insert_embeddings( assert self.cursor is not None, "Cursor is not initialized" try: - items = { - "id": metadata, - "embedding": embeddings - } - df = pd.DataFrame(items) - csv_buffer = io.StringIO() - df.to_csv(csv_buffer, index=False, header=False) - csv_buffer.seek(0) - self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer) + metadata_arr = np.array(metadata) + embeddings_arr = np.array(embeddings) + + with self.cursor.copy( + sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( + table_name=sql.Identifier(self.table_name) + ) + ) as copy: + copy.set_types(["bigint", "vector"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, embeddings_arr[i])) self.conn.commit() + + if kwargs.get("last_batch"): + self._post_insert() + return len(metadata), None except Exception as e: - log.warning(f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}") + log.warning( + f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}" + ) return 0, e def search_embedding( @@ -173,20 +275,18 @@ def search_embedding( assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" - search_param = self.case_config.search_param() + q = np.asarray(query) if filters: + log.debug(self._filtered_search.as_string(self.cursor)) gt = filters.get("id") - self.cursor.execute( - f"SELECT id FROM (SELECT * FROM public.\"{self.table_name}\" ORDER BY embedding \ - {search_param['metrics_op']} '{query}' LIMIT {k}) AS X WHERE id > {gt} ;" + result = self.cursor.execute( + self._filtered_search, (gt, q, k), prepare=True, binary=True ) else: - self.cursor.execute( - f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding \ - {search_param['metrics_op']} '{query}' LIMIT {k};" + log.debug(self._unfiltered_search.as_string(self.cursor)) + result = self.cursor.execute( + self._unfiltered_search, (q, k), prepare=True, binary=True ) - self.conn.commit() - result = self.cursor.fetchall() - return [int(i[0]) for i in result] + return [int(i[0]) for i in result.fetchall()] diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 0b619bbec..5bf260600 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,4 +1,5 @@ from ..backend.clients.pgvector.cli import PgVectorHNSW +from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat from ..backend.clients.redis.cli import Redis from ..backend.clients.test.cli import Test from ..backend.clients.weaviate_cloud.cli import Weaviate @@ -10,6 +11,8 @@ from .cli import cli cli.add_command(PgVectorHNSW) +cli.add_command(PgVectoRSHNSW) +cli.add_command(PgVectoRSIVFFlat) cli.add_command(Redis) cli.add_command(Weaviate) cli.add_command(Test) diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index ce8a3a4ae..687f1efbf 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -190,6 +190,19 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_IndexType_PgVectoRS = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputHelp="Select Index Type", + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.HNSW.value, + IndexType.IVFFlat.value, + IndexType.Flat.value, + ], + }, +) + CaseConfigParamInput_M = CaseConfigInput( label=CaseConfigParamType.M, inputType=InputType.Number, @@ -272,14 +285,26 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_EFConstruction_PgVectoRS = CaseConfigInput( - label=CaseConfigParamType.EFConstruction, + label=CaseConfigParamType.ef_construction, inputType=InputType.Number, inputConfig={ - "min": 8, - "max": 512, - "value": 360, + "min": 10, + "max": 2000, + "value": 300, }, - isDisplayed=lambda config: config[CaseConfigParamType.IndexType] + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + +CaseConfigParamInput_EFSearch_PgVectoRS = CaseConfigInput( + label=CaseConfigParamType.ef_search, + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 65535, + "value": 100, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) @@ -598,6 +623,7 @@ class CaseConfigInput(BaseModel): == IndexType.HNSW.value, ) + CaseConfigParamInput_QuantizationType_PgVectoRS = CaseConfigInput( label=CaseConfigParamType.quantizationType, inputType=InputType.Option, @@ -626,6 +652,18 @@ class CaseConfigInput(BaseModel): ], ) +CaseConfigParamInput_max_parallel_workers_PgVectorRS = CaseConfigInput( + label=CaseConfigParamType.max_parallel_workers, + displayLabel="Max parallel workers", + inputHelp="Recommended value: (cpu cores - 1). This will set the parameters: [optimizing.optimizing_threads]", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 1024, + "value": 16, + }, +) + CaseConfigParamInput_ZillizLevel = CaseConfigInput( label=CaseConfigParamType.level, inputType=InputType.Number, @@ -707,22 +745,25 @@ class CaseConfigInput(BaseModel): ] PgVectoRSLoadingConfig = [ - CaseConfigParamInput_IndexType, - CaseConfigParamInput_M, + CaseConfigParamInput_IndexType_PgVectoRS, + CaseConfigParamInput_m, CaseConfigParamInput_EFConstruction_PgVectoRS, CaseConfigParamInput_Nlist, CaseConfigParamInput_QuantizationType_PgVectoRS, CaseConfigParamInput_QuantizationRatio_PgVectoRS, + CaseConfigParamInput_max_parallel_workers_PgVectorRS, ] PgVectoRSPerformanceConfig = [ - CaseConfigParamInput_IndexType, - CaseConfigParamInput_M, + CaseConfigParamInput_IndexType_PgVectoRS, + CaseConfigParamInput_m, CaseConfigParamInput_EFConstruction_PgVectoRS, + CaseConfigParamInput_EFSearch_PgVectoRS, CaseConfigParamInput_Nlist, CaseConfigParamInput_Nprobe, CaseConfigParamInput_QuantizationType_PgVectoRS, CaseConfigParamInput_QuantizationRatio_PgVectoRS, + CaseConfigParamInput_max_parallel_workers_PgVectorRS, ] ZillizCloudPerformanceConfig = [ diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 56034796e..73845cf5d 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -2,7 +2,7 @@ import pathlib from datetime import date from enum import Enum, StrEnum, auto -from typing import List, Self, Sequence, Set +from typing import List, Self import ujson @@ -10,7 +10,6 @@ DB, DBConfig, DBCaseConfig, - IndexType, ) from .backend.cases import CaseType from .base import BaseModel @@ -46,8 +45,8 @@ class CaseConfigParamType(Enum): numCandidates = "num_candidates" lists = "lists" probes = "probes" - quantizationType = "quantizationType" - quantizationRatio = "quantizationRatio" + quantizationType = "quantization_type" + quantizationRatio = "quantization_ratio" m = "m" nbits = "nbits" intermediate_graph_degree = "intermediate_graph_degree"