From e0119a6b958102eab0ecd980ec24fd40b082e585 Mon Sep 17 00:00:00 2001
From: "zihengsjtu@gmail.com"
Date: Fri, 19 Jul 2024 11:28:50 +0800
Subject: [PATCH] finish alloy

---
 .../backend/clients/alloydb/alloy.py  | 245 ++++++++++++++++++
 .../backend/clients/alloydb/config.py | 215 +++++++++++++++
 .../frontend/config/dbCaseConfigs.py  |   4 +
 3 files changed, 464 insertions(+)
 create mode 100644 vectordb_bench/backend/clients/alloydb/alloy.py
 create mode 100644 vectordb_bench/backend/clients/alloydb/config.py

diff --git a/vectordb_bench/backend/clients/alloydb/alloy.py b/vectordb_bench/backend/clients/alloydb/alloy.py
new file mode 100644
index 000000000..d3bffe2b7
--- /dev/null
+++ b/vectordb_bench/backend/clients/alloydb/alloy.py
@@ -0,0 +1,245 @@
import logging
from contextlib import contextmanager
from typing import Any, Generator, Optional, Sequence, Tuple

import numpy as np
import psycopg2
from pgvector.psycopg2 import register_vector
from psycopg2.extras import execute_values

from ..api import VectorDB
from .config import PgVectorConfigDict, PgVectorIndexConfig

log = logging.getLogger(__name__)


class AlloyDB(VectorDB):
    """AlloyDB client for VectorDBBench, backed by the pgvector extension."""

    def __init__(
        self,
        dim: int,
        db_config: PgVectorConfigDict,
        db_case_config: PgVectorIndexConfig,
        collection_name: str = "pg_vector_collection",
        drop_old: bool = False,
        **kwargs,
    ):
        self.name = "AlloyDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim

        self._index_name = "hnsw"
        self._primary_field = "id"
        self._vector_field = "embedding"

        # Establish the initial connection; _create_connection also creates
        # the vector extension if it is missing.
        self.conn, self.cursor = self._create_connection(**self.db_config)

        if drop_old:
            self._drop_index()
            self._drop_table()
            self._create_table(dim)
            if self.case_config.create_index_before_load:
                self._create_index()

        self.cursor.close()
        self.conn.close()
        self.cursor = None
        self.conn = None

    @staticmethod
    def _create_connection(**kwargs):
        conn = psycopg2.connect(
            host=kwargs["host"],
            port=kwargs["port"],
            user=kwargs["user"],
            password=kwargs["password"],
            dbname=kwargs["dbname"],
        )
        conn.autocommit = False
        cursor = conn.cursor()

        # Ensure pgvector is available before registering its type adapter
        # on this connection.
        cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        conn.commit()
        register_vector(conn)

        assert conn is not None, "Connection is not initialized"
        assert cursor is not None, "Cursor is not initialized"
        return conn, cursor

    def _drop_table(self):
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        self.cursor.execute(
            f"""
            DROP TABLE IF EXISTS public.{self.table_name}
            """
        )
        self.conn.commit()

    def ready_to_load(self):
        pass

    def optimize(self):
        self._post_insert()

    def _post_insert(self):
        log.info(f"{self.name} post insert before optimize")
        if self.case_config.create_index_after_load:
            self._drop_index()
            self._create_index()

    def _drop_index(self):
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop index : {self._index_name}")

        drop_index_sql = f"""
            DROP INDEX IF EXISTS {self._index_name}
        """
        self.cursor.execute(drop_index_sql)
        self.conn.commit()

    @contextmanager
    def init(self) -> Generator[None, None, None]:
        """
        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        self.conn, self.cursor = self._create_connection(**self.db_config)

        # The index configuration may define session-level settings (e.g.
        # hnsw.ef_search) that have to be applied to each client session.
        session_options: Sequence[dict[str, Any]] = self.case_config.session_param()["session_options"]
        for setting in session_options:
            command = f"""SET {setting['parameter']['setting_name']} = {setting['parameter']['val']}"""
            self.cursor.execute(command)
        self.conn.commit()

        try:
            yield
        finally:
            self.cursor.close()
            self.conn.close()
            self.cursor = None
            self.conn = None

    def _set_parallel_index_build_param(self):
        pass

    def _create_index(self):
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()
        index_create_sql = f"""
            CREATE INDEX IF NOT EXISTS {self._index_name}
            ON public.{self.table_name}
            USING {index_param["index_type"]} (embedding {index_param["metric"]})
        """
        self.cursor.execute(index_create_sql)
        self.conn.commit()

    def _create_table(self, dim: int):
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        self.cursor.execute(
            f"""
            CREATE TABLE IF NOT EXISTS public.{self.table_name}
            (id BIGINT PRIMARY KEY, embedding vector({dim}));
            """
        )
        # PLAIN storage keeps vectors stored inline rather than TOASTed.
        self.cursor.execute(
            f"""
            ALTER TABLE public.{self.table_name}
            ALTER COLUMN embedding SET STORAGE PLAIN;
            """
        )
        self.conn.commit()

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> Tuple[int, Optional[Exception]]:
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            # Batch the rows and commit once, instead of a round trip and
            # commit per row.
            rows = [
                (meta, np.array(embedding))
                for meta, embedding in zip(metadata, embeddings)
            ]
            execute_values(
                self.cursor,
                f"INSERT INTO public.{self.table_name} (id, embedding) VALUES %s;",
                rows,
            )
            self.conn.commit()
            return len(metadata), None
        except Exception as e:
            log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        arr = np.array(query)
        # Use the distance operator matching the configured metric instead of
        # hard-coding cosine distance.
        op = self.case_config.search_param()["metric_fun_op"]
        self.cursor.execute(
            f"""
            SELECT id FROM public.{self.table_name} ORDER BY embedding {op} %s LIMIT {k};
            """,
            (arr,),
        )
        result = self.cursor.fetchall()
        return [int(i[0]) for i in result]
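
A quick standalone sanity check may help reviewers here: a minimal sketch (connection values are placeholders, not part of this patch) mirroring what _create_connection does, including the numpy round-trip that register_vector enables:

    import numpy as np
    import psycopg2
    from pgvector.psycopg2 import register_vector

    # Placeholder connection values; point these at a reachable
    # AlloyDB/Postgres instance allowed to create extensions.
    conn = psycopg2.connect(host="127.0.0.1", port=5432, user="postgres",
                            password="postgres", dbname="postgres")
    cursor = conn.cursor()
    cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    conn.commit()
    register_vector(conn)  # adapts numpy arrays to/from the `vector` type

    cursor.execute("SELECT %s::vector;", (np.array([1.0, 2.0, 3.0]),))
    print(cursor.fetchone()[0])  # comes back as a numpy array
    conn.close()
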
diff --git a/vectordb_bench/backend/clients/alloydb/config.py b/vectordb_bench/backend/clients/alloydb/config.py
new file mode 100644
index 000000000..7e848adb2
--- /dev/null
+++ b/vectordb_bench/backend/clients/alloydb/config.py
@@ -0,0 +1,215 @@
from abc import abstractmethod
from typing import Any, Mapping, Optional, Sequence, TypedDict

from pydantic import BaseModel, SecretStr
from typing_extensions import LiteralString

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType

POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"


class PgVectorConfigDict(TypedDict):
    """These keys are passed directly as psycopg connection kwargs,
    so the names must match the psycopg API exactly."""

    user: str
    password: str
    host: str
    port: int
    dbname: str


class PgVectorConfig(DBConfig):
    user_name: SecretStr = SecretStr("postgres")
    password: SecretStr
    host: str = "10.100.0.10"
    port: int = 5432
    db_name: str

    def to_dict(self) -> PgVectorConfigDict:
        user_str = self.user_name.get_secret_value()
        pwd_str = self.password.get_secret_value()
        return {
            "host": self.host,
            "port": self.port,
            "dbname": self.db_name,
            "user": user_str,
            "password": pwd_str,
        }


class PgVectorIndexParam(TypedDict):
    metric: str
    index_type: str
    index_creation_with_options: Sequence[dict[str, Any]]
    maintenance_work_mem: Optional[str]
    max_parallel_workers: Optional[int]


class PgVectorSearchParam(TypedDict):
    metric_fun_op: LiteralString


class PgVectorSessionCommands(TypedDict):
    session_options: Sequence[dict[str, Any]]


class PgVectorIndexConfig(BaseModel, DBCaseConfig):
    metric_type: MetricType | None = None
    create_index_before_load: bool = False
    create_index_after_load: bool = True

    def parse_metric(self) -> str:
        if self.metric_type == MetricType.L2:
            return "vector_l2_ops"
        elif self.metric_type == MetricType.IP:
            return "vector_ip_ops"
        return "vector_cosine_ops"

    def parse_metric_fun_op(self) -> LiteralString:
        if self.metric_type == MetricType.L2:
            return "<->"
        elif self.metric_type == MetricType.IP:
            return "<#>"
        return "<=>"

    def parse_metric_fun_str(self) -> str:
        if self.metric_type == MetricType.L2:
            return "l2_distance"
        elif self.metric_type == MetricType.IP:
            return "max_inner_product"
        return "cosine_distance"

    @abstractmethod
    def index_param(self) -> PgVectorIndexParam:
        ...

    @abstractmethod
    def search_param(self) -> PgVectorSearchParam:
        ...

    @abstractmethod
    def session_param(self) -> PgVectorSessionCommands:
        ...

    @staticmethod
    def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
        """Walk through the mapping, building a list of {option_name, val}
        pairs used to render the index's WITH clause."""
        options = []
        for option_name, value in with_options.items():
            if value is not None:
                options.append(
                    {
                        "option_name": option_name,
                        "val": str(value),
                    }
                )
        return options

    @staticmethod
    def _optionally_build_set_options(
        set_mapping: Mapping[str, Any]
    ) -> Sequence[dict[str, Any]]:
        """Walk through the mapping, building the parameters for
        `SET <name> = <value>;` session commands."""
        session_options = []
        for setting_name, value in set_mapping.items():
            if value:
                session_options.append(
                    {
                        "parameter": {
                            "setting_name": setting_name,
                            "val": str(value),
                        },
                    }
                )
        return session_options

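# For illustration, what the concrete subclasses below produce (values are
# examples, not defaults): an HNSW case config with MetricType.L2 maps to the
# `vector_l2_ops` opclass, the `<->` operator, and one session command.
#
#   cfg = PgVectorHNSWConfig(metric_type=MetricType.L2, m=16,
#                            ef_construction=64, ef_search=128)
#   cfg.index_param()["metric"]          # "vector_l2_ops"
#   cfg.search_param()["metric_fun_op"]  # "<->"
#   cfg.session_param()
#   # {"session_options": [{"parameter":
#   #     {"setting_name": "hnsw.ef_search", "val": "128"}}]}
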
class PgVectorIVFFlatConfig(PgVectorIndexConfig):
    """
    An IVFFlat index divides vectors into lists, then searches a subset of
    the lists closest to the query vector. It builds faster and uses less
    memory than HNSW, but has lower query performance (in terms of the
    speed-recall tradeoff).

    Three keys to achieving good recall:

    - Create the index after the table has some data.
    - Choose an appropriate number of lists: a good starting point is
      rows / 1000 for up to 1M rows and sqrt(rows) above 1M rows.
    - When querying, specify an appropriate number of probes (higher is
      better for recall, lower is better for speed): a good starting point
      is sqrt(lists). See the worked example after this class.
    """

    lists: int | None
    probes: int | None
    index: IndexType = IndexType.ES_IVFFlat
    maintenance_work_mem: Optional[str] = None
    max_parallel_workers: Optional[int] = None

    def index_param(self) -> PgVectorIndexParam:
        index_parameters = {"lists": self.lists}
        return {
            "metric": self.parse_metric(),
            "index_type": self.index.value,
            "index_creation_with_options": self._optionally_build_with_options(
                index_parameters
            ),
            "maintenance_work_mem": self.maintenance_work_mem,
            "max_parallel_workers": self.max_parallel_workers,
        }

    def search_param(self) -> PgVectorSearchParam:
        return {
            "metric_fun_op": self.parse_metric_fun_op(),
        }

    def session_param(self) -> PgVectorSessionCommands:
        session_parameters = {"ivfflat.probes": self.probes}
        return {
            "session_options": self._optionally_build_set_options(session_parameters)
        }

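# Worked example of the IVFFlat starting points above (an illustrative helper
# only; nothing else in this patch calls it): lists = rows / 1000 for up to
# 1M rows and sqrt(rows) beyond, probes = sqrt(lists).
import math


def _suggested_ivfflat_params(rows: int) -> tuple[int, int]:
    lists = rows // 1000 if rows <= 1_000_000 else round(math.sqrt(rows))
    probes = max(1, round(math.sqrt(lists)))
    return lists, probes


# _suggested_ivfflat_params(500_000)    -> (500, 22)
# _suggested_ivfflat_params(5_000_000)  -> (2236, 47)
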
class PgVectorHNSWConfig(PgVectorIndexConfig):
    """
    An HNSW index creates a multilayer graph. It has better query performance
    than IVFFlat (in terms of the speed-recall tradeoff), but builds more
    slowly and uses more memory. Unlike IVFFlat, it has no training step, so
    the index can be created before the table has any data.
    """

    m: int | None  # Valid values are between 2 and 100.
    ef_construction: (
        int | None
    )  # ef_construction must be greater than or equal to 2 * m.
    ef_search: int | None
    index: IndexType = IndexType.ES_HNSW
    maintenance_work_mem: Optional[str] = None
    max_parallel_workers: Optional[int] = None

    def index_param(self) -> PgVectorIndexParam:
        index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
        return {
            "metric": self.parse_metric(),
            "index_type": self.index.value,
            "index_creation_with_options": self._optionally_build_with_options(
                index_parameters
            ),
            "maintenance_work_mem": self.maintenance_work_mem,
            "max_parallel_workers": self.max_parallel_workers,
        }

    def search_param(self) -> PgVectorSearchParam:
        return {
            "metric_fun_op": self.parse_metric_fun_op(),
        }

    def session_param(self) -> PgVectorSessionCommands:
        session_parameters = {"hnsw.ef_search": self.ef_search}
        return {
            "session_options": self._optionally_build_set_options(session_parameters)
        }


_pgvector_case_config = {
    IndexType.HNSW: PgVectorHNSWConfig,
    IndexType.ES_HNSW: PgVectorHNSWConfig,
    IndexType.IVFFlat: PgVectorIVFFlatConfig,
}
\ No newline at end of file

diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py
index ce8a3a4ae..21eec9c2b 100644
--- a/vectordb_bench/frontend/config/dbCaseConfigs.py
+++ b/vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -753,4 +753,8 @@ class CaseConfigInput(BaseModel):
         CaseLabel.Load: PgVectoRSLoadingConfig,
         CaseLabel.Performance: PgVectoRSPerformanceConfig,
     },
+    DB.Alloy: {
+        CaseLabel.Load: PgVectorLoadingConfig,
+        CaseLabel.Performance: PgVectorPerformanceConfig,
+    },
 }
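
Finally, a hedged end-to-end sketch of driving the new client directly (import paths follow the files added above; connection values are placeholders, and this assumes a reachable AlloyDB/Postgres instance plus the DB.Alloy enum registered elsewhere):

    from vectordb_bench.backend.clients.alloydb.alloy import AlloyDB
    from vectordb_bench.backend.clients.alloydb.config import PgVectorHNSWConfig

    # Placeholder connection values; dbname is required now that
    # _create_connection passes it through to psycopg2.connect.
    db = AlloyDB(
        dim=4,
        db_config={"host": "127.0.0.1", "port": 5432, "user": "postgres",
                   "password": "postgres", "dbname": "vectordb"},
        db_case_config=PgVectorHNSWConfig(m=16, ef_construction=64, ef_search=100),
        drop_old=True,
    )
    with db.init():  # applies SET hnsw.ef_search = 100 for the session
        db.insert_embeddings([[0.1, 0.2, 0.3, 0.4]], [1])
        print(db.search_embedding([0.1, 0.2, 0.3, 0.4], k=1))  # -> [1]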