-
Notifications
You must be signed in to change notification settings - Fork 167
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: yangxuan <[email protected]>
- Loading branch information
1 parent
1ab46dd
commit 49bccd1
Showing
9 changed files
with
413 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from typing import Iterable | ||
import argparse | ||
from vectordb_bench.backend.dataset import Dataset, DatasetSource | ||
from vectordb_bench.backend.runner.rate_runner import RatedMultiThreadingInsertRunner | ||
from vectordb_bench.backend.runner.read_write_runner import ReadWriteRunner | ||
from vectordb_bench.backend.clients import DB, VectorDB | ||
from vectordb_bench.backend.clients.milvus.config import FLATConfig | ||
from vectordb_bench.backend.clients.zilliz_cloud.config import AutoIndexConfig | ||
|
||
import logging | ||
|
||
log = logging.getLogger("vectordb_bench") | ||
log.setLevel(logging.DEBUG) | ||
|
||
def get_rate_runner(db): | ||
cohere = Dataset.COHERE.manager(100_000) | ||
prepared = cohere.prepare(DatasetSource.AliyunOSS) | ||
assert prepared | ||
runner = RatedMultiThreadingInsertRunner( | ||
rate = 10, | ||
db = db, | ||
dataset = cohere, | ||
) | ||
|
||
return runner | ||
|
||
def test_rate_runner(db, insert_rate): | ||
runner = get_rate_runner(db) | ||
|
||
_, t = runner.run_with_rate() | ||
log.info(f"insert run done, time={t}") | ||
|
||
def test_read_write_runner(db, insert_rate, conc: list, search_stage: Iterable[float], read_dur_after_write: int, local: bool=False): | ||
cohere = Dataset.COHERE.manager(1_000_000) | ||
if local is True: | ||
source = DatasetSource.AliyunOSS | ||
else: | ||
source = DatasetSource.S3 | ||
prepared = cohere.prepare(source) | ||
assert prepared | ||
|
||
rw_runner = ReadWriteRunner( | ||
db=db, | ||
dataset=cohere, | ||
insert_rate=insert_rate, | ||
search_stage=search_stage, | ||
read_dur_after_write=read_dur_after_write, | ||
concurrencies=conc | ||
) | ||
rw_runner.run_read_write() | ||
|
||
|
||
def get_db(db: str, config: dict) -> VectorDB: | ||
if db == DB.Milvus.name: | ||
return DB.Milvus.init_cls(dim=768, db_config=config, db_case_config=FLATConfig(metric_type="COSINE"), drop_old=True, pre_load=True) | ||
elif db == DB.ZillizCloud.name: | ||
return DB.ZillizCloud.init_cls(dim=768, db_config=config, db_case_config=AutoIndexConfig(metric_type="COSINE"), drop_old=True, pre_load=True) | ||
else: | ||
raise ValueError(f"unknown db: {db}") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-r", "--insert_rate", type=int, default="1000", help="insert entity row count per seconds, cps") | ||
parser.add_argument("-d", "--db", type=str, default=DB.Milvus.name, help="db name") | ||
parser.add_argument("-t", "--duration", type=int, default=300, help="stage search duration in seconds") | ||
parser.add_argument("--use_s3", action='store_true', help="whether to use S3 dataset") | ||
|
||
flags = parser.parse_args() | ||
|
||
# TODO read uri, user, password from .env | ||
config = { | ||
"uri": "http://localhost:19530", | ||
"user": "", | ||
"password": "", | ||
} | ||
|
||
conc = (1, 15, 50) | ||
search_stage = (0.5, 0.6, 0.7, 0.8, 0.9, 1.0) | ||
|
||
db = get_db(flags.db, config) | ||
test_read_write_runner( | ||
db=db, | ||
insert_rate=flags.insert_rate, | ||
conc=conc, | ||
search_stage=search_stage, | ||
read_dur_after_write=flags.duration, | ||
local=flags.use_s3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import logging | ||
import time | ||
from concurrent.futures import ThreadPoolExecutor | ||
import multiprocessing as mp | ||
|
||
|
||
from vectordb_bench.backend.clients import api | ||
from vectordb_bench.backend.dataset import DataSetIterator | ||
from vectordb_bench.backend.utils import time_it | ||
from vectordb_bench import config | ||
|
||
from .util import get_data, is_futures_completed, get_future_exceptions | ||
log = logging.getLogger(__name__) | ||
|
||
|
||
class RatedMultiThreadingInsertRunner: | ||
def __init__( | ||
self, | ||
rate: int, # numRows per second | ||
db: api.VectorDB, | ||
dataset_iter: DataSetIterator, | ||
normalize: bool = False, | ||
timeout: float | None = None, | ||
): | ||
self.timeout = timeout if isinstance(timeout, (int, float)) else None | ||
self.dataset = dataset_iter | ||
self.db = db | ||
self.normalize = normalize | ||
self.insert_rate = rate | ||
self.batch_rate = rate // config.NUM_PER_BATCH | ||
|
||
def send_insert_task(self, db, emb: list[list[float]], metadata: list[str]): | ||
db.insert_embeddings(emb, metadata) | ||
|
||
@time_it | ||
def run_with_rate(self, q: mp.Queue): | ||
with ThreadPoolExecutor(max_workers=mp.cpu_count()) as executor: | ||
executing_futures = [] | ||
|
||
@time_it | ||
def submit_by_rate() -> bool: | ||
rate = self.batch_rate | ||
for data in self.dataset: | ||
emb, metadata = get_data(data, self.normalize) | ||
executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata)) | ||
rate -= 1 | ||
|
||
if rate == 0: | ||
return False | ||
return rate == self.batch_rate | ||
|
||
with self.db.init(): | ||
while True: | ||
start_time = time.perf_counter() | ||
finished, elapsed_time = submit_by_rate() | ||
if finished is True: | ||
q.put(None, block=True) | ||
log.info(f"End of dataset, left unfinished={len(executing_futures)}") | ||
return | ||
|
||
q.put(True, block=False) | ||
wait_interval = 1 - elapsed_time if elapsed_time < 1 else 0.001 | ||
|
||
e, completed = is_futures_completed(executing_futures, wait_interval) | ||
if completed is True: | ||
ex = get_future_exceptions(executing_futures) | ||
if ex is not None: | ||
log.warn(f"task error, terminating, err={ex}") | ||
q.put(None) | ||
executor.shutdown(wait=True, cancel_futures=True) | ||
raise ex | ||
else: | ||
log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}") | ||
executing_futures = [] | ||
else: | ||
log.warning(f"Failed to finish tasks in 1s, {e}, waited={wait_interval:.2f}, try to check the next round") | ||
dur = time.perf_counter() - start_time | ||
if dur < 1: | ||
time.sleep(1 - dur) |
Oops, something went wrong.