enhance: refine read write cases
1. Control search time between two insert stages, making sure search doesn't
drift away from the insert proportions.
2. Collect the NDCG metric.
3. After insertion, run optimize, then serial search and concurrent search.

Signed-off-by: yangxuan <[email protected]>
XuanYang-cn authored and alwayslove2013 committed Dec 11, 2024
1 parent 854278a commit a8fdc1a
Showing 7 changed files with 151 additions and 79 deletions.
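
Item 1 of the commit message budgets the concurrent-search window between two insert stages so searching cannot outrun the insert schedule. A minimal worked sketch of that budget, using hypothetical numbers (data_volume, insert_rate, the stage boundaries and the serial-search cost are all made up here) and mirroring the arithmetic added to read_write_runner.py below:

```python
# Hypothetical numbers, for illustration only; real values come from the dataset and runner config.
data_volume = 500_000            # total rows in the dataset
insert_rate = 500                # rows inserted per second
stage, next_stage = 0.5, 0.6     # two consecutive search stages
concurrencies = (1, 15, 50)
serial_search_dur = 12.0         # measured serial-search cost at this stage, in seconds

# Wall-clock time available between the two stages while insertion keeps running.
total_dur_between_stages = data_volume * (next_stage - stage) // insert_rate  # 99.0 (nominally 100 s; float floor division shaves 1 s)

# Budget left for concurrent search once the serial search has finished.
csearch_dur = total_dur_between_stages - serial_search_dur                    # 87.0
# Leave ~30 s of headroom for spawning the concurrency executors.
csearch_dur = csearch_dur - 30 if csearch_dur > 60 else csearch_dur           # 57.0

each_conc_search_dur = csearch_dur / len(concurrencies)                       # 19.0
if each_conc_search_dur < 30:
    print("warning: per-concurrency duration is short; results may be inaccurate")
```

With these numbers the per-concurrency window falls below the 30 s guard, which is exactly the case the new warning in run_search_by_sig reports.
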
6 changes: 3 additions & 3 deletions tests/test_rate_runner.py
@@ -52,9 +52,9 @@ def test_read_write_runner(db, insert_rate, conc: list, search_stage: Iterable[f

def get_db(db: str, config: dict) -> VectorDB:
if db == DB.Milvus.name:
return DB.Milvus.init_cls(dim=768, db_config=config, db_case_config=FLATConfig(metric_type="COSINE"), drop_old=True, pre_load=True)
return DB.Milvus.init_cls(dim=768, db_config=config, db_case_config=FLATConfig(metric_type="COSINE"), drop_old=True)
elif db == DB.ZillizCloud.name:
return DB.ZillizCloud.init_cls(dim=768, db_config=config, db_case_config=AutoIndexConfig(metric_type="COSINE"), drop_old=True, pre_load=True)
return DB.ZillizCloud.init_cls(dim=768, db_config=config, db_case_config=AutoIndexConfig(metric_type="COSINE"), drop_old=True)
else:
raise ValueError(f"unknown db: {db}")

@@ -76,7 +76,7 @@ def get_db(db: str, config: dict) -> VectorDB:
}

conc = (1, 15, 50)
search_stage = (0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
search_stage = (0.5, 0.6, 0.7, 0.8, 0.9)

db = get_db(flags.db, config)
test_read_write_runner(
12 changes: 5 additions & 7 deletions vectordb_bench/backend/clients/milvus/milvus.py
@@ -8,7 +8,7 @@
from pymilvus import Collection, utility
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException

from ..api import VectorDB, IndexType
from ..api import VectorDB
from .config import MilvusIndexConfig


@@ -66,8 +66,7 @@ def __init__(
self.case_config.index_param(),
index_name=self._index_name,
)
if kwargs.get("pre_load") is True:
self._pre_load(col)
col.load()

connections.disconnect("default")

@@ -90,16 +89,15 @@ def init(self) -> None:
connections.disconnect("default")

def _optimize(self):
self._post_insert()
log.info(f"{self.name} optimizing before search")
self._post_insert()
try:
self.col.load()
self.col.load(refresh=True)
except Exception as e:
log.warning(f"{self.name} optimize error: {e}")
raise e from None

def _post_insert(self):
log.info(f"{self.name} post insert before optimize")
try:
self.col.flush()
# wait for index done and load refresh
@@ -130,7 +128,7 @@ def wait_index():
log.warning(f"{self.name} compact error: {e}")
if hasattr(e, 'code'):
if e.code().name == 'PERMISSION_DENIED':
log.warning(f"Skip compact due to permission denied.")
log.warning("Skip compact due to permission denied.")
pass
else:
raise e
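
For context, the reshuffled _optimize above now logs first, runs the post-insert maintenance, and only then reloads with refresh=True. A rough sketch of that sequence against a bare pymilvus Collection, assuming an already established connection and an existing collection and index (the name "demo" is made up):

```python
from pymilvus import Collection, utility

def optimize(collection_name: str = "demo") -> None:
    col = Collection(collection_name)
    col.flush()                                                 # persist in-memory segments
    utility.wait_for_index_building_complete(collection_name)   # let indexing catch up
    col.compact()                                               # merge small segments
    col.wait_for_compaction_completed()
    utility.wait_for_index_building_complete(collection_name)
    col.load(refresh=True)                                      # reload so newly sealed segments are searchable
```
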
47 changes: 32 additions & 15 deletions vectordb_bench/backend/runner/rate_runner.py
@@ -1,5 +1,6 @@
import logging
import time
import concurrent
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp

@@ -9,7 +10,7 @@
from vectordb_bench.backend.utils import time_it
from vectordb_bench import config

from .util import get_data, is_futures_completed, get_future_exceptions
from .util import get_data
log = logging.getLogger(__name__)


@@ -54,26 +55,42 @@ def submit_by_rate() -> bool:
start_time = time.perf_counter()
finished, elapsed_time = submit_by_rate()
if finished is True:
q.put(None, block=True)
q.put(True, block=True)
log.info(f"End of dataset, left unfinished={len(executing_futures)}")
return
break

q.put(True, block=False)
q.put(False, block=False)
wait_interval = 1 - elapsed_time if elapsed_time < 1 else 0.001

e, completed = is_futures_completed(executing_futures, wait_interval)
if completed is True:
ex = get_future_exceptions(executing_futures)
if ex is not None:
log.warn(f"task error, terminating, err={ex}")
q.put(None)
executor.shutdown(wait=True, cancel_futures=True)
raise ex
try:
done, not_done = concurrent.futures.wait(
executing_futures,
timeout=wait_interval,
return_when=concurrent.futures.FIRST_EXCEPTION)

if len(not_done) > 0:
log.warning(f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round")
executing_futures = list(not_done)
else:
log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}")
executing_futures = []
else:
log.warning(f"Failed to finish tasks in 1s, {e}, waited={wait_interval:.2f}, try to check the next round")
executing_futures = []
except Exception as e:
log.warn(f"task error, terminating, err={e}")
q.put(None, block=True)
executor.shutdown(wait=True, cancel_futures=True)
raise e

dur = time.perf_counter() - start_time
if dur < 1:
time.sleep(1 - dur)

# wait for all tasks in executing_futures to complete
if len(executing_futures) > 0:
try:
done, _ = concurrent.futures.wait(executing_futures,
return_when=concurrent.futures.FIRST_EXCEPTION)
except Exception as e:
log.warn(f"task error, terminating, err={e}")
q.put(None, block=True)
executor.shutdown(wait=True, cancel_futures=True)
raise e
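
The rewritten loop above waits on the outstanding insert futures once per second and carries the unfinished ones into the next round instead of discarding them. A stripped-down sketch of that pattern outside vectordb_bench, with made-up names (do_insert, rate_limited_inserts, signal_q) and the same queue convention — False for a progress tick, True for normal completion, None for an error:

```python
import concurrent.futures
import queue
import time
from concurrent.futures import ThreadPoolExecutor


def do_insert(batch: int) -> int:
    """Stand-in for inserting one batch of embeddings."""
    time.sleep(0.3)
    return batch


def rate_limited_inserts(total_batches: int, per_second: int, signal_q) -> None:
    executing: list[concurrent.futures.Future] = []
    with ThreadPoolExecutor(max_workers=per_second) as executor:
        for start in range(0, total_batches, per_second):
            round_begin = time.perf_counter()
            executing.extend(executor.submit(do_insert, b)
                             for b in range(start, min(start + per_second, total_batches)))
            signal_q.put(False)  # progress tick for the search side

            elapsed = time.perf_counter() - round_begin
            wait_interval = 1 - elapsed if elapsed < 1 else 0.001
            done, not_done = concurrent.futures.wait(
                executing,
                timeout=wait_interval,
                return_when=concurrent.futures.FIRST_EXCEPTION,
            )
            for f in done:  # surface task errors immediately
                if f.exception() is not None:
                    signal_q.put(None)
                    raise f.exception()
            executing = list(not_done)  # keep waiting on stragglers next round

            spent = time.perf_counter() - round_begin
            if spent < 1:
                time.sleep(1 - spent)

        # drain whatever is still running before signalling normal completion
        concurrent.futures.wait(executing, return_when=concurrent.futures.FIRST_EXCEPTION)
        signal_q.put(True)


if __name__ == "__main__":
    rate_limited_inserts(total_batches=10, per_second=4, signal_q=queue.Queue())
```

Carrying not_done forward is what keeps slow batches from being silently dropped when a round overruns its one-second budget.
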
138 changes: 102 additions & 36 deletions vectordb_bench/backend/runner/read_write_runner.py
@@ -24,15 +24,15 @@ def __init__(
k: int = 100,
filters: dict | None = None,
concurrencies: Iterable[int] = (1, 15, 50),
search_stage: Iterable[float] = (0.5, 0.6, 0.7, 0.8, 0.9, 1.0), # search in any insert portion, 0.0 means search from the start
search_stage: Iterable[float] = (0.5, 0.6, 0.7, 0.8, 0.9), # search from insert portion, 0.0 means search from the start
read_dur_after_write: int = 300, # seconds, search duration when insertion is done
timeout: float | None = None,
):
self.insert_rate = insert_rate
self.data_volume = dataset.data.size

for stage in search_stage:
assert 0.0 <= stage <= 1.0, "each search stage should be in [0.0, 1.0]"
assert 0.0 <= stage < 1.0, "each search stage should be in [0.0, 1.0)"
self.search_stage = sorted(search_stage)
self.read_dur_after_write = read_dur_after_write

@@ -65,48 +65,114 @@ def __init__(
k=k,
)

def run_optimize(self):
"""Optimize needs to run in differenct process for pymilvus schema recursion problem"""
with self.db.init():
log.info("Search after write - Optimize start")
self.db.optimize()
log.info("Search after write - Optimize finished")

def run_search(self):
log.info("Search after write - Serial search start")
res, ssearch_dur = self.serial_search_runner.run()
recall, ndcg, p99_latency = res
log.info(f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")
log.info(f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}")
max_qps = self.run_by_dur(self.read_dur_after_write)
log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}")

return (max_qps, recall, ndcg, p99_latency)

def run_read_write(self):
futures = []
with mp.Manager() as m:
q = m.Queue()
with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("spawn"), max_workers=2) as executor:
futures.append(executor.submit(self.run_with_rate, q))
futures.append(executor.submit(self.run_search_by_sig, q))

for future in concurrent.futures.as_completed(futures):
res = future.result()
log.info(f"Result = {res}")

read_write_futures = []
read_write_futures.append(executor.submit(self.run_with_rate, q))
read_write_futures.append(executor.submit(self.run_search_by_sig, q))

try:
for f in concurrent.futures.as_completed(read_write_futures):
res = f.result()
log.info(f"Result = {res}")

# Wait for read_write_futures to finish, then run optimize and search
op_future = executor.submit(self.run_optimize)
op_future.result()

search_future = executor.submit(self.run_search)
last_res = search_future.result()

log.info(f"Max QPS after optimze and search: {last_res}")
except Exception as e:
log.warning(f"Read and write error: {e}")
executor.shutdown(wait=True, cancel_futures=True)
raise e
log.info("Concurrent read write all done")


def run_search_by_sig(self, q):
res = []
"""
Args:
q: multiprocessing queue
(None) means abnormal exit
(False) means updating progress
(True) means normal exit
"""
result, start_batch = [], 0
total_batch = math.ceil(self.data_volume / self.insert_rate)
batch = 0
recall = 'x'
recall, ndcg, p99_latency = None, None, None

def wait_next_target(start, target_batch) -> bool:
"""Return False when receive True or None"""
while start < target_batch:
sig = q.get(block=True)

if sig is None or sig is True:
return False
else:
start += 1
return True

for idx, stage in enumerate(self.search_stage):
target_batch = int(total_batch * stage)
while q.get(block=True):
batch += 1
if batch >= target_batch:
perc = int(stage * 100)
log.info(f"Insert {perc}% done, total batch={total_batch}")
log.info(f"[{batch}/{total_batch}] Serial search - {perc}% start")
recall, ndcg, p99 =self.serial_search_runner.run()

if idx < len(self.search_stage) - 1:
stage_search_dur = (self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate) // len(self.concurrencies)
if stage_search_dur < 30:
log.warning(f"Search duration too short, please reduce concurrency count or insert rate, or increase dataset volume: dur={stage_search_dur}, concurrencies={len(self.concurrencies)}, insert_rate={self.insert_rate}")
log.info(f"[{batch}/{total_batch}] Conc search - {perc}% start, dur for each conc={stage_search_dur}s")
else:
last_search_dur = self.data_volume * (1.0 - stage) // self.insert_rate
stage_search_dur = last_search_dur + self.read_dur_after_write
log.info(f"[{batch}/{total_batch}] Last conc search - {perc}% start, [read_until_write|read_after_write|total] =[{last_search_dur}s|{self.read_dur_after_write}s|{stage_search_dur}s]")

max_qps = self.run_by_dur(stage_search_dur)
res.append((perc, max_qps, recall))
break
return res
perc = int(stage * 100)

got = wait_next_target(start_batch, target_batch)
if got is False:
log.warning(f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}")
return

log.info(f"Insert {perc}% done, total batch={total_batch}")
log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start")
res, ssearch_dur = self.serial_search_runner.run()
recall, ndcg, p99_latency = res
log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")

# Search duration for non-last search stage is carefully calculated.
# If the duration for each concurrency is less than 30s, the runner logs a warning.
if idx < len(self.search_stage) - 1:
total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate
csearch_dur = total_dur_between_stages - ssearch_dur

# Try to leave room for init process executors
csearch_dur = csearch_dur - 30 if csearch_dur > 60 else csearch_dur

each_conc_search_dur = csearch_dur / len(self.concurrencies)
if each_conc_search_dur < 30:
warning_msg = f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}."
log.warning(warning_msg)

# The last stage
else:
each_conc_search_dur = 60

log.info(f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}")
max_qps = self.run_by_dur(each_conc_search_dur)
result.append((perc, max_qps, recall, ndcg, p99_latency))

start_batch = target_batch

# Drain the queue
while q.empty() is False:
q.get(block=True)
return result
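
The new run_search_by_sig consumes the queue protocol spelled out in its docstring: each False is one finished insert batch, True marks the end of the dataset, and None marks an abnormal producer exit. A minimal consumer sketch of that protocol, with a made-up run_stage_search callback standing in for the serial and concurrent searches:

```python
import math
import multiprocessing as mp


def wait_next_target(q, start: int, target_batch: int) -> tuple[int, bool]:
    """Advance on False ticks; stop early on True (normal end) or None (producer error)."""
    while start < target_batch:
        sig = q.get(block=True)
        if sig is None or sig is True:
            return start, False
        start += 1
    return start, True


def search_by_signal(q, data_volume: int, insert_rate: int,
                     search_stage=(0.5, 0.6, 0.7, 0.8, 0.9),
                     run_stage_search=print):
    total_batch = math.ceil(data_volume / insert_rate)
    start_batch = 0
    for stage in sorted(search_stage):
        target_batch = int(total_batch * stage)
        start_batch, reached = wait_next_target(q, start_batch, target_batch)
        if not reached:
            return  # producer stopped early or failed
        run_stage_search(f"insert {int(stage * 100)}% done, searching now")
    # drain remaining signals so the producer never blocks on put()
    while not q.empty():
        q.get(block=True)


if __name__ == "__main__":
    q = mp.Manager().Queue()
    for _ in range(90):   # pretend 90 of 100 insert batches completed
        q.put(False)
    q.put(True)           # producer signals end of dataset
    search_by_signal(q, data_volume=1000, insert_rate=10)
```
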
10 changes: 8 additions & 2 deletions vectordb_bench/backend/runner/serial_runner.py
@@ -167,7 +167,7 @@ def __init__(
self.test_data = test_data
self.ground_truth = ground_truth

def search(self, args: tuple[list, pd.DataFrame]):
def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
with self.db.init():
test_data, ground_truth = args
@@ -224,5 +224,11 @@ def _run_in_subprocess(self) -> tuple[float, float]:
result = future.result()
return result

def run(self) -> tuple[float, float]:
@utils.time_it
def run(self) -> tuple[float, float, float]:
"""
Returns:
tuple[tuple[float, float, float], float]: (avg_recall, avg_ndcg, p99_latency), cost
"""
return self._run_in_subprocess()
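
Since run() is now wrapped by utils.time_it, callers get the metric triple together with the elapsed time, as the new docstring states. A self-contained sketch of that unpacking, with a stub decorator and a stub run() standing in for SerialSearchRunner (only the (result, elapsed) return shape is taken from the diff):

```python
import time
from functools import wraps


def time_it(func):
    """Stub mirroring the documented behaviour: return (result, elapsed_seconds)."""
    @wraps(func)
    def inner(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        return result, time.perf_counter() - start
    return inner


@time_it
def run() -> tuple[float, float, float]:
    # stand-in for SerialSearchRunner.run(): (avg_recall, avg_ndcg, p99_latency)
    return 0.98, 0.95, 0.012


res, ssearch_dur = run()
recall, ndcg, p99_latency = res
print(f"recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")
```
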
16 changes: 0 additions & 16 deletions vectordb_bench/backend/runner/util.py
@@ -1,6 +1,4 @@
import logging
import concurrent
from typing import Iterable

from pandas import DataFrame
import numpy as np
@@ -16,17 +14,3 @@ def get_data(data_df: DataFrame, normalize: bool) -> tuple[list[list[float]], li
else:
all_embeddings = emb_np.tolist()
return all_embeddings, all_metadata

def is_futures_completed(futures: Iterable[concurrent.futures.Future], interval) -> (Exception, bool):
try:
list(concurrent.futures.as_completed(futures, timeout=interval))
except TimeoutError as e:
return e, False
return None, True


def get_future_exceptions(futures: Iterable[concurrent.futures.Future]) -> BaseException | None:
for f in futures:
if f.exception() is not None:
return f.exception()
return
1 change: 1 addition & 0 deletions vectordb_bench/backend/utils.py
@@ -35,6 +35,7 @@ def numerize(n) -> str:


def time_it(func):
""" returns result and elapsed time"""
@wraps(func)
def inner(*args, **kwargs):
pref = time.perf_counter()
