enhance: Remove last batch in insert_embeddings #314

Merged (1 commit, Apr 30, 2024)
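In brief, as the diff below shows: the last_batch kwarg is removed from insert_embeddings, so clients no longer trigger post-insert work (flush, compaction, index build) on the final batch; that work moves into each client's optimize step, where it runs once between the load phase and the search phase. An illustrative call sequence after the change (a hypothetical driver function, for illustration only; insert_embeddings, optimize, and search_embedding are the client methods from the diff):

    # Hypothetical driver, not runner code from this repo.
    def load_then_search(db, batches, query):
        for embeddings, metadata in batches:
            # insert_embeddings no longer receives a last_batch flag
            count, err = db.insert_embeddings(embeddings=embeddings, metadata=metadata)
            if err is not None:
                raise err
        db.optimize()  # post-insert work (flush, compaction, index build) happens here
        return db.search_embedding(query, k=100)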
5 changes: 2 additions & 3 deletions vectordb_bench/backend/clients/milvus/milvus.py
@@ -89,6 +89,7 @@ def init(self) -> None:
connections.disconnect("default")

def _optimize(self):
+ self._post_insert()
log.info(f"{self.name} optimizing before search")
try:
self.col.load()
@@ -116,7 +117,7 @@ def wait_index():
time.sleep(5)

wait_index()

# Skip compaction if use GPU indexType
if self.case_config.index in [IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ]:
log.debug("skip compaction for gpu index type.")
@@ -179,8 +180,6 @@ def insert_embeddings(
]
res = self.col.insert(insert_data)
insert_count += len(res.primary_keys)
- if kwargs.get("last_batch"):
-     self._post_insert()
except MilvusException as e:
log.info(f"Failed to insert data: {e}")
return (insert_count, e)
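Net effect on the Milvus client, condensed into a sketch (abbreviated from the diff above; connection setup, error handling, and the rest of the class omitted):

    # Condensed sketch of the Milvus client after this change; not the full class.
    class MilvusSketch:
        def insert_embeddings(self, embeddings, metadata, **kwargs):
            # Pure write path now: no last_batch check, no flush or index wait here.
            res = self.col.insert([metadata, embeddings])
            return len(res.primary_keys), None

        def _optimize(self):
            # Post-insert work (flush, compaction, index wait) moved here, so it
            # runs exactly once when the benchmark moves from load to search.
            self._post_insert()
            log.info(f"{self.name} optimizing before search")
            self.col.load()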
45 changes: 19 additions & 26 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -10,7 +10,7 @@

from ..api import IndexType, VectorDB, DBCaseConfig

log = logging.getLogger(__name__)

class PgVector(VectorDB):
""" Use SQLAlchemy instructions"""
@@ -36,20 +36,20 @@ def __init__(
# construct basic units
self.conn = psycopg2.connect(**self.db_config)
self.conn.autocommit = False
self.cursor = self.conn.cursor()

# create vector extension
self.cursor.execute('CREATE EXTENSION IF NOT EXISTS vector')
self.conn.commit()

if drop_old :
log.info(f"Pgvector client drop table : {self.table_name}")
# self.pg_table.drop(pg_engine, checkfirst=True)
self._drop_index()
self._drop_table()
self._create_table(dim)
self._create_index()

self.cursor.close()
self.conn.close()
self.cursor = None
@@ -66,47 +66,44 @@ def init(self) -> None:
self.conn = psycopg2.connect(**self.db_config)
self.conn.autocommit = False
self.cursor = self.conn.cursor()

try:
yield
finally:
self.cursor.close()
self.conn.close()
self.cursor = None
self.conn = None

def _drop_table(self):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

self.cursor.execute(f'DROP TABLE IF EXISTS public."{self.table_name}"')
self.conn.commit()

def ready_to_load(self):
pass

def optimize(self):
- pass
-
- def _post_insert(self):
- log.info(f"{self.name} post insert before optimize")
+ log.info(f"{self.name} optimizing")
self._drop_index()
self._create_index()

def ready_to_search(self):
pass

def _drop_index(self):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

self.cursor.execute(f'DROP INDEX IF EXISTS "{self._index_name}"')
self.conn.commit()

def _create_index(self):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()
if self.case_config.index == IndexType.HNSW:
log.debug(f'Creating HNSW index. m={index_param["m"]}, ef_construction={index_param["ef_construction"]}')
@@ -117,11 +114,11 @@ def _create_index():
else:
assert "Invalid index type {self.case_config.index}"
self.conn.commit()

def _create_table(self, dim : int):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

try:
# create table
self.cursor.execute(f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" (id BIGINT PRIMARY KEY, embedding vector({dim}));')
@@ -151,16 +148,13 @@ def insert_embeddings(
csv_buffer.seek(0)
self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer)
self.conn.commit()

- if kwargs.get("last_batch"):
-     self._post_insert()
-

return len(metadata), None
except Exception as e:
log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
return 0, e

def search_embedding(
self,
query: list[float],
k: int = 100,
@@ -184,4 +178,3 @@ def search_embedding(
result = self.cursor.fetchall()

return [int(i[0]) for i in result]
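For pgvector, the index rebuild that used to fire on the last batch now lives in optimize() (per the hunk above). Rebuilding once after the bulk load avoids paying per-row index maintenance during inserts. A minimal sketch of the resulting method (abbreviated; _drop_index and _create_index as defined earlier in this file):

    # Minimal sketch of pgvector's optimize() after this change (abbreviated).
    def optimize(self):
        log.info(f"{self.name} optimizing")
        self._drop_index()    # DROP INDEX IF EXISTS on the benchmark table
        self._create_index()  # rebuild (e.g. HNSW) over the fully loaded table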

2 changes: 0 additions & 2 deletions vectordb_bench/backend/runner/serial_runner.py
@@ -46,11 +46,9 @@ def task(self) -> int:
del(emb_np)
log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}")

- last_batch = self.dataset.data.size - count == len(all_metadata)
insert_count, error = self.db.insert_embeddings(
embeddings=all_embeddings,
metadata=all_metadata,
- last_batch=last_batch,
)
if error is not None:
raise error
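With the flag gone, the serial insert loop in serial_runner.py no longer tracks whether the current batch is the final one. A condensed view of the loop after this diff (variable names from the hunk; surrounding bookkeeping omitted, and the batches iterable is a stand-in for the dataset iteration):

    # Condensed insert loop from task() after this change (abbreviated).
    count = 0
    for all_embeddings, all_metadata in batches:
        insert_count, error = self.db.insert_embeddings(
            embeddings=all_embeddings,
            metadata=all_metadata,   # note: no last_batch kwarg anymore
        )
        if error is not None:
            raise error
        count += insert_count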
2 changes: 1 addition & 1 deletion vectordb_bench/backend/task_runner.py
@@ -140,8 +140,8 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric:
)

self._init_search_runner()
- m.recall, m.serial_latency_p99 = self._serial_search()
m.qps = self._conc_search()
+ m.recall, m.serial_latency_p99 = self._serial_search()
except Exception as e:
log.warning(f"Failed to run performance case, reason = {e}")
traceback.print_exc()
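The task_runner change reorders the search phase: concurrent QPS is now measured before the serial pass that produces recall and p99 latency. Condensed, using the lines from the hunk above:

    # Search-phase order in _run_perf_case after this change (condensed).
    self._init_search_runner()
    m.qps = self._conc_search()                             # concurrent search first
    m.recall, m.serial_latency_p99 = self._serial_search()  # then serial recall / p99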