
enhance: Remove last batch in insert_embeddings
Signed-off-by: yangxuan <[email protected]>
XuanYang-cn committed Apr 30, 2024
1 parent 0d75990 commit a3134fb
Showing 4 changed files with 22 additions and 32 deletions.
5 changes: 2 additions & 3 deletions vectordb_bench/backend/clients/milvus/milvus.py
@@ -89,6 +89,7 @@ def init(self) -> None:
connections.disconnect("default")

def _optimize(self):
self._post_insert()
log.info(f"{self.name} optimizing before search")
try:
self.col.load()
@@ -116,7 +117,7 @@ def wait_index():
time.sleep(5)

wait_index()

# Skip compaction when using a GPU index type
if self.case_config.index in [IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ]:
log.debug("skip compaction for gpu index type.")
@@ -179,8 +180,6 @@ def insert_embeddings(
]
res = self.col.insert(insert_data)
insert_count += len(res.primary_keys)
if kwargs.get("last_batch"):
self._post_insert()
except MilvusException as e:
log.info(f"Failed to insert data: {e}")
return (insert_count, e)
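Net effect of the milvus.py hunks: insert_embeddings no longer checks a last_batch kwarg, and the post-insert work now runs once at the start of _optimize. A condensed sketch of the resulting flow, assuming the surrounding attributes (self.col, self.name, log) from the diff; the class name and bodies below are illustrative, not the repository's exact code:

    class MilvusFlowSketch:
        """Illustrative condensation of the two hunks above."""

        def insert_embeddings(self, embeddings, metadata, **kwargs):
            # Pure data path now: every batch is a plain insert,
            # with no per-call last_batch flag to check.
            insert_data = [metadata, embeddings]
            res = self.col.insert(insert_data)
            return len(res.primary_keys), None

        def _optimize(self):
            # Post-insert work (index build and wait) runs exactly once,
            # before the search phase, instead of inside the insert loop.
            self._post_insert()
            log.info(f"{self.name} optimizing before search")
            self.col.load()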
45 changes: 19 additions & 26 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -10,7 +10,7 @@

from ..api import IndexType, VectorDB, DBCaseConfig

log = logging.getLogger(__name__)

class PgVector(VectorDB):
""" Use SQLAlchemy instructions"""
@@ -36,20 +36,20 @@ def __init__(
# construct basic units
self.conn = psycopg2.connect(**self.db_config)
self.conn.autocommit = False
self.cursor = self.conn.cursor()

# create vector extension
self.cursor.execute('CREATE EXTENSION IF NOT EXISTS vector')
self.conn.commit()

if drop_old :
log.info(f"Pgvector client drop table : {self.table_name}")
# self.pg_table.drop(pg_engine, checkfirst=True)
self._drop_index()
self._drop_table()
self._create_table(dim)
self._create_index()

self.cursor.close()
self.conn.close()
self.cursor = None
@@ -66,47 +66,44 @@ def init(self) -> None:
self.conn = psycopg2.connect(**self.db_config)
self.conn.autocommit = False
self.cursor = self.conn.cursor()

try:
yield
finally:
self.cursor.close()
self.conn.close()
self.cursor = None
self.conn = None

def _drop_table(self):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

self.cursor.execute(f'DROP TABLE IF EXISTS public."{self.table_name}"')
self.conn.commit()

def ready_to_load(self):
pass

def optimize(self):
pass

def _post_insert(self):
log.info(f"{self.name} post insert before optimize")
log.info(f"{self.name} optimizing")
self._drop_index()
self._create_index()

def ready_to_search(self):
pass

def _drop_index(self):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

self.cursor.execute(f'DROP INDEX IF EXISTS "{self._index_name}"')
self.conn.commit()

def _create_index(self):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()
if self.case_config.index == IndexType.HNSW:
log.debug(f'Creating HNSW index. m={index_param["m"]}, ef_construction={index_param["ef_construction"]}')
@@ -117,11 +114,11 @@ def _create_index(self):
else:
assert "Invalid index type {self.case_config.index}"
self.conn.commit()

def _create_table(self, dim : int):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

try:
# create table
self.cursor.execute(f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" (id BIGINT PRIMARY KEY, embedding vector({dim}));')
@@ -151,16 +148,13 @@ def insert_embeddings(
csv_buffer.seek(0)
self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer)
self.conn.commit()

if kwargs.get("last_batch"):
self._post_insert()


return len(metadata), None
except Exception as e:
log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
return 0, e

def search_embedding(
self,
query: list[float],
k: int = 100,
Expand All @@ -184,4 +178,3 @@ def search_embedding(
result = self.cursor.fetchall()

return [int(i[0]) for i in result]
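
The pgvector insert path above streams each batch through PostgreSQL's COPY protocol (cursor.copy_expert) instead of issuing row-by-row INSERTs. A minimal, self-contained sketch of that technique, assuming a table shaped like the one created above (id BIGINT, embedding vector(n)); the function and argument names are illustrative:

    import io
    import psycopg2

    def copy_vectors(conn, table: str, ids, embeddings):
        # Serialize rows as CSV; pgvector accepts the text form "[x1,x2,...]"
        # for vector columns, quoted so its commas survive CSV parsing.
        buf = io.StringIO()
        for pk, emb in zip(ids, embeddings):
            vec = "[" + ",".join(str(x) for x in emb) + "]"
            buf.write(f'{pk},"{vec}"\n')
        buf.seek(0)
        with conn.cursor() as cur:
            # Same COPY statement shape as insert_embeddings above.
            cur.copy_expert(f'COPY public."{table}" FROM STDIN WITH (FORMAT CSV)', buf)
        conn.commit()

A single COPY of a buffered batch avoids per-row round trips, which is why the benchmark client builds a CSV buffer rather than looping over execute().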

2 changes: 0 additions & 2 deletions vectordb_bench/backend/runner/serial_runner.py
@@ -46,11 +46,9 @@ def task(self) -> int:
del(emb_np)
log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}")

last_batch = self.dataset.data.size - count == len(all_metadata)
insert_count, error = self.db.insert_embeddings(
embeddings=all_embeddings,
metadata=all_metadata,
last_batch=last_batch,
)
if error is not None:
raise error
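With the last_batch computation removed, the serial runner's loop reduces to plain batched inserts. A condensed sketch of the simplified loop, assuming an iterable of (embeddings, metadata) batches; the function name and batching iterator are assumptions for illustration:

    def insert_all(db, batches) -> int:
        count = 0
        for all_embeddings, all_metadata in batches:
            # No last_batch flag: the final batch is inserted like any other,
            # and post-insert work is deferred to the client's optimize step.
            insert_count, error = db.insert_embeddings(
                embeddings=all_embeddings,
                metadata=all_metadata,
            )
            if error is not None:
                raise error
            count += insert_count
        return count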
2 changes: 1 addition & 1 deletion vectordb_bench/backend/task_runner.py
@@ -140,8 +140,8 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric:
)

self._init_search_runner()
m.recall, m.serial_latency_p99 = self._serial_search()
m.qps = self._conc_search()
m.recall, m.serial_latency_p99 = self._serial_search()
except Exception as e:
log.warning(f"Failed to run performance case, reason = {e}")
traceback.print_exc()
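The task_runner.py hunk only swaps the order of the two search measurements: concurrent QPS now runs before the serial recall/latency pass. The resulting sequence, condensed from the lines above:

    self._init_search_runner()
    m.qps = self._conc_search()                              # concurrent throughput first
    m.recall, m.serial_latency_p99 = self._serial_search()   # then serial recall and p99 latency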
