Skip to content

Commit

Permalink
refactor pgvector code
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 authored and XuanYang-cn committed Jul 12, 2023
1 parent 934f153 commit e65413e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 43 deletions.
16 changes: 6 additions & 10 deletions vectordb_bench/backend/clients/pgvector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from ..api import DBConfig, DBCaseConfig, MetricType

POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
INDEX_TYPE = "ivfflat"

class PgVectorConfig(DBConfig):
user_name: SecretStr
user_name: SecretStr = "postgres"
password: SecretStr
url: SecretStr
db_name: str
Expand All @@ -20,8 +19,8 @@ def to_dict(self) -> dict:

class PgVectorIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
lists: int | None = 10
probes: int | None = 1
lists: int | None = 1000
probes: int | None = 10

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
Expand All @@ -39,15 +38,12 @@ def parse_metric_fun_str(self) -> str:

def index_param(self) -> dict:
    """Return the parameters used when building the vector index.

    Returns:
        A dict with two keys:
        ``"lists"``  - the configured ``lists`` value of this case config.
        ``"metric"`` - the string produced by ``parse_metric()``.
    """
    # Only these two keys are read by the index-creation code; the
    # dialect-specific ``postgresql_*`` options are assembled there.
    return {
        "lists": self.lists,
        "metric": self.parse_metric(),
    }

def search_param(self) -> dict:
    """Return the parameters applied at query time.

    Returns:
        A dict with two keys:
        ``"probes"``     - the configured ``probes`` value.
        ``"metric_fun"`` - the string produced by ``parse_metric_fun_str()``.
    """
    probes = self.probes
    metric_fun = self.parse_metric_fun_str()
    return dict(probes=probes, metric_fun=metric_fun)


}
108 changes: 75 additions & 33 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,23 @@
import time
from contextlib import contextmanager
from typing import Any, Type
from functools import wraps

from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
from pgvector.sqlalchemy import Vector
from .config import PgVectorConfig, PgVectorIndexConfig
from sqlalchemy import (
MetaData,
create_engine,
insert,
select,
Index,
Table,
text,
Column,
Float,
Integer
)
from sqlalchemy.orm import (
declarative_base,
mapped_column,
Expand All @@ -24,34 +38,34 @@ def __init__(
db_case_config: DBCaseConfig,
collection_name: str = "PgVectorCollection",
drop_old: bool = False,
**kwargs,
):
self.db_config = db_config
self.case_config = db_case_config
self.table_name = collection_name
self.dim = dim

self._index_name = "pqvector_index"
self._primary_field = "id"
self._vector_field = "embedding"

# construct basic units
pq_metadata = MetaData()
self.pg_engine = create_engine(**self.db_config)
pg_engine = create_engine(**self.db_config)
Base = declarative_base()
pq_metadata = Base.metadata
pq_metadata.reflect(pg_engine)

# create vector extension
with self.pg_engine as conn:
with pg_engine.connect() as conn:
conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
conn.commit()

self.pg_table = Table(
self.table_name,
pq_metadata,
Column(self._primary_field, Integer, primary_key=True),
Column(self._vector_field, Vector(dim))
)

self.pg_table = self._get_table_schema(pq_metadata)
if drop_old and self.table_name in pq_metadata.tables:
log.info(f"Pgvector client drop table : {self.table_name}")
self.self.pq_table.drop(bind = engine)

self._create_table(dim)
# self.pg_table.drop(pg_engine, checkfirst=True)
pq_metadata.drop_all(pg_engine)
self._create_table(dim, pg_engine)


@classmethod
Expand All @@ -70,25 +84,53 @@ def init(self) -> None:
>>> self.insert_embeddings()
>>> self.search_embedding()
"""
self.pq_session = Session(self.pg_engine)
self.pg_engine = create_engine(**self.db_config)

Base = declarative_base()
pq_metadata = Base.metadata
pq_metadata.reflect(self.pg_engine)
self.pg_session = Session(self.pg_engine)
self.pg_table = self._get_table_schema(pq_metadata)
yield
self.pq_session = None
del (self.pq_session)
self.pg_session = None
self.pg_engine = None
del (self.pg_session)
del (self.pg_engine)

def ready_to_load(self):
    """Do nothing; this client has no pre-load step."""
    return None

def optimize(self):
    """Do nothing; this client has no optimize step."""
    return None

def ready_to_search(self):
    """Do nothing; this client has no pre-search step."""
    return None

def _get_table_schema(self, pq_metadata):
    """Build the SQLAlchemy ``Table`` for this collection on *pq_metadata*.

    The table is named ``self.table_name`` and has two columns: an integer
    primary key named ``self._primary_field`` and a ``Vector(self.dim)``
    column named ``self._vector_field``.

    Args:
        pq_metadata: the ``MetaData`` object the table is attached to.

    Returns:
        The constructed ``Table`` object.
    """
    id_column = Column(self._primary_field, Integer, primary_key=True)
    vector_column = Column(self._vector_field, Vector(self.dim))
    # extend_existing=True: if a table with this name is already registered
    # on the metadata, apply these column definitions to it instead of
    # raising.
    return Table(
        self.table_name,
        pq_metadata,
        id_column,
        vector_column,
        extend_existing=True,
    )

def _create_index(self, pg_engine):
    """Drop (if present) and create the ivfflat index on the vector column.

    The ``lists`` storage parameter and the operator class are taken from
    ``self.case_config.index_param()`` (keys ``"lists"`` and ``"metric"``).

    Args:
        pg_engine: the SQLAlchemy engine the DDL statements are run on.
    """
    index_param = self.case_config.index_param()
    # Address the column through ``_vector_field`` so the index follows the
    # same column name that the table schema is built with.
    index = Index(
        self._index_name,
        self.pg_table.c[self._vector_field],
        postgresql_using="ivfflat",
        postgresql_with={"lists": index_param["lists"]},
        postgresql_ops={self._vector_field: index_param["metric"]},
    )
    # checkfirst=True makes the drop a no-op when the index does not exist.
    index.drop(pg_engine, checkfirst=True)
    index.create(pg_engine)

def _create_table(self, dim : int):
def _create_table(self, dim, pg_engine : int):
try:
self.pg_table.create(bind = self.pg_engine, checkfirst = True)
self._create_index()
# create table
self.pg_table.create(bind = pg_engine, checkfirst = True)
# create vec index
self._create_index(pg_engine)
except Exception as e:
log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
raise e from None
Expand All @@ -100,10 +142,10 @@ def insert_embeddings(
**kwargs: Any,
) -> (int, Exception):
try:
items = [dict(id = metadata[i], embedding=embeddings[i]) for i in range(metadata)]
self.pq_session.execute(insert(table), items)
self.pq_session.commit()
return len(items), None
items = [dict(id = metadata[i], embedding=embeddings[i]) for i in range(len(metadata))]
self.pg_session.execute(insert(self.pg_table), items)
self.pg_session.commit()
return len(metadata), None
except Exception as e:
log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
return 0, e
Expand All @@ -114,16 +156,16 @@ def search_embedding(
k: int = 100,
filters: dict | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> list[int]:
assert self.pq_table is not None
with self.pg_engine as conn:
conn.execute(text(f'SET ivfflat.probes = {kwargs["probes"]}'))
assert self.pg_table is not None
search_param =self.case_config.search_param()
with self.pg_engine.connect() as conn:
conn.execute(text(f'SET ivfflat.probes = {search_param["probes"]}'))
conn.commit()
op_fun = getattr(table.c.embedding, kwargs["metric_fun"])
op_fun = getattr(self.pg_table.c.embedding, search_param["metric_fun"])
if filters:
res = self.pq_session.scalars(select(self.pq_table.order_by(op_fun(query)).filter(self.pq_table.c.id > filters.get('id')).limit(k)))
res = self.pg_session.scalars(select(self.pg_table).order_by(op_fun(query)).filter(self.pg_table.c.id > filters.get('id')).limit(k))
else:
res = self.pq_session.scalars(select(self.pq_table.order_by(op_fun(query)).limit(k)))
res = self.pg_session.scalars(select(self.pg_table).order_by(op_fun(query)).limit(k))
return list(res)

0 comments on commit e65413e

Please sign in to comment.