diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index 4c60013ee..d40004b4f 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -2,10 +2,9 @@ from ..api import DBConfig, DBCaseConfig, MetricType POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" -INDEX_TYPE = "ivfflat" class PgVectorConfig(DBConfig): - user_name: SecretStr + user_name: SecretStr = "postgres" password: SecretStr url: SecretStr db_name: str @@ -20,8 +19,8 @@ def to_dict(self) -> dict: class PgVectorIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None - lists: int | None = 10 - probes: int | None = 1 + lists: int | None = 1000 + probes: int | None = 10 def parse_metric(self) -> str: if self.metric_type == MetricType.L2: @@ -39,15 +38,12 @@ def parse_metric_fun_str(self) -> str: def index_param(self) -> dict: return { - "postgresql_using" : INDEX_TYPE, - "postgresql_with" : {'lists': self.lists}, - "postgresql_ops": self.parse_metric() + "lists" : self.lists, + "metric" : self.parse_metric() } def search_param(self) -> dict: return { "probes" : self.probes, "metric_fun" : self.parse_metric_fun_str() - } - - \ No newline at end of file + } \ No newline at end of file diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 46baa9a24..53f1000ab 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -4,9 +4,23 @@ import time from contextlib import contextmanager from typing import Any, Type +from functools import wraps from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType +from pgvector.sqlalchemy import Vector from .config import PgVectorConfig, PgVectorIndexConfig +from sqlalchemy import ( + MetaData, + create_engine, + insert, + select, + Index, + Table, + text, + Column, + Float, + Integer +) from sqlalchemy.orm import ( declarative_base, mapped_column, @@ -24,34 +38,34 @@ def __init__( db_case_config: DBCaseConfig, collection_name: str = "PgVectorCollection", drop_old: bool = False, + **kwargs, ): self.db_config = db_config self.case_config = db_case_config self.table_name = collection_name + self.dim = dim self._index_name = "pqvector_index" self._primary_field = "id" self._vector_field = "embedding" # construct basic units - pq_metadata = MetaData() - self.pg_engine = create_engine(**self.db_config) + pg_engine = create_engine(**self.db_config) + Base = declarative_base() + pq_metadata = Base.metadata + pq_metadata.reflect(pg_engine) + # create vector extension - with self.pg_engine as conn: + with pg_engine.connect() as conn: conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector')) conn.commit() - - self.pg_table = Table( - self.table_name, - pq_metadata, - Column(self._primary_field, Integer, primary_key=True), - Column(self._vector_field, Vector(dim)) - ) + + self.pg_table = self._get_table_schema(pq_metadata) if drop_old and self.table_name in pq_metadata.tables: log.info(f"Pgvector client drop table : {self.table_name}") - self.self.pq_table.drop(bind = engine) - - self._create_table(dim) + # self.pg_table.drop(pg_engine, checkfirst=True) + pq_metadata.drop_all(pg_engine) + self._create_table(dim, pg_engine) @classmethod @@ -70,25 +84,53 @@ def init(self) -> None: >>> self.insert_embeddings() >>> self.search_embedding() """ - self.pq_session = Session(self.pg_engine) + self.pg_engine = create_engine(**self.db_config) + + Base = declarative_base() + pq_metadata = Base.metadata + pq_metadata.reflect(self.pg_engine) + self.pg_session = Session(self.pg_engine) + self.pg_table = self._get_table_schema(pq_metadata) yield - self.pq_session = None - del (self.pq_session) + self.pg_session = None + self.pg_engine = None + del (self.pg_session) + del (self.pg_engine) def ready_to_load(self): pass + def optimize(self): + pass + def ready_to_search(self): pass + + def _get_table_schema(self, pq_metadata): + return Table( + self.table_name, + pq_metadata, + Column(self._primary_field, Integer, primary_key=True), + Column(self._vector_field, Vector(self.dim)), + extend_existing=True + ) - def _create_index(self): - index = Index(self._index_name, self.pq_table.embedding, **self.case_config.index_param()) - index.create(self.pg_engine) + def _create_index(self, pg_engine): + index_param = self.case_config.index_param() + index = Index(self._index_name, self.pg_table.c.embedding, + postgresql_using='ivfflat', + postgresql_with={'lists': index_param["lists"]}, + postgresql_ops={'embedding': index_param["metric"]} + ) + index.drop(pg_engine, checkfirst = True) + index.create(pg_engine) - def _create_table(self, dim : int): + def _create_table(self, dim, pg_engine : int): try: - self.pg_table.create(bind = self.pg_engine, checkfirst = True) - self._create_index() + # create table + self.pg_table.create(bind = pg_engine, checkfirst = True) + # create vec index + self._create_index(pg_engine) except Exception as e: log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}") raise e from None @@ -100,10 +142,10 @@ def insert_embeddings( **kwargs: Any, ) -> (int, Exception): try: - items = [dict(id = metadata[i], embedding=embeddings[i]) for i in range(metadata)] - self.pq_session.execute(insert(table), items) - self.pq_session.commit() - return len(items), None + items = [dict(id = metadata[i], embedding=embeddings[i]) for i in range(len(metadata))] + self.pg_session.execute(insert(self.pg_table), items) + self.pg_session.commit() + return len(metadata), None except Exception as e: log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}") return 0, e @@ -114,16 +156,16 @@ def search_embedding( k: int = 100, filters: dict | None = None, timeout: int | None = None, - **kwargs: Any, ) -> list[int]: - assert self.pq_table is not None - with self.pg_engine as conn: - conn.execute(text(f'SET ivfflat.probes = {kwargs["probes"]}')) + assert self.pg_table is not None + search_param =self.case_config.search_param() + with self.pg_engine.connect() as conn: + conn.execute(text(f'SET ivfflat.probes = {search_param["probes"]}')) conn.commit() - op_fun = getattr(table.c.embedding, kwargs["metric_fun"]) + op_fun = getattr(self.pg_table.c.embedding, search_param["metric_fun"]) if filters: - res = self.pq_session.scalars(select(self.pq_table.order_by(op_fun(query)).filter(self.pq_table.c.id > filters.get('id')).limit(k))) + res = self.pg_session.scalars(select(self.pg_table).order_by(op_fun(query)).filter(self.pg_table.c.id > filters.get('id')).limit(k)) else: - res = self.pq_session.scalars(select(self.pq_table.order_by(op_fun(query)).limit(k))) + res = self.pg_session.scalars(select(self.pg_table).order_by(op_fun(query)).limit(k)) return list(res) \ No newline at end of file