From fd2b1860d7c8a10a415b43afb3b980c116388a6e Mon Sep 17 00:00:00 2001 From: yangxuan Date: Thu, 18 Jan 2024 11:16:15 +0800 Subject: [PATCH] Enable aliyun OSS Add data_source.py, vdb bench now can download dataset from Aliyun OSS. Signed-off-by: yangxuan --- pyproject.toml | 1 + tests/test_data_source.py | 78 ++++++++++ tests/test_dataset.py | 33 ++++- vectordb_bench/__init__.py | 6 +- vectordb_bench/backend/assembler.py | 14 +- vectordb_bench/backend/data_source.py | 204 ++++++++++++++++++++++++++ vectordb_bench/backend/dataset.py | 195 ++++++++++-------------- vectordb_bench/backend/task_runner.py | 6 +- vectordb_bench/interface.py | 16 +- vectordb_bench/log_util.py | 1 - 10 files changed, 418 insertions(+), 136 deletions(-) create mode 100644 tests/test_data_source.py create mode 100644 vectordb_bench/backend/data_source.py diff --git a/pyproject.toml b/pyproject.toml index f73bc2940..075ec92be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "streamlit_extras", "tqdm", "s3fs", + "oss2", "psutil", "polars", "plotly", diff --git a/tests/test_data_source.py b/tests/test_data_source.py new file mode 100644 index 000000000..a2b1d0ed8 --- /dev/null +++ b/tests/test_data_source.py @@ -0,0 +1,78 @@ +import logging +import pathlib +import pytest +from vectordb_bench.backend.data_source import AliyunOSSReader, AwsS3Reader +from vectordb_bench.backend.dataset import Dataset, DatasetManager + +log = logging.getLogger(__name__) + +class TestReader: + @pytest.mark.parametrize("size", [ + 100_000, + 1_000_000, + 10_000_000, + ]) + def test_cohere(self, size): + cohere = Dataset.COHERE.manager(size) + self.per_dataset_test(cohere) + + @pytest.mark.parametrize("size", [ + 100_000, + 1_000_000, + ]) + def test_gist(self, size): + gist = Dataset.GIST.manager(size) + self.per_dataset_test(gist) + + @pytest.mark.parametrize("size", [ + 1_000_000, + ]) + def test_glove(self, size): + glove = Dataset.GLOVE.manager(size) + self.per_dataset_test(glove) + + @pytest.mark.parametrize("size", [ + 500_000, + 5_000_000, + # 50_000_000, + ]) + def test_sift(self, size): + sift = Dataset.SIFT.manager(size) + self.per_dataset_test(sift) + + @pytest.mark.parametrize("size", [ + 50_000, + 500_000, + 5_000_000, + ]) + def test_openai(self, size): + openai = Dataset.OPENAI.manager(size) + self.per_dataset_test(openai) + + + def per_dataset_test(self, dataset: DatasetManager): + s3_reader = AwsS3Reader() + all_files = s3_reader.ls_all(dataset.data.dir_name) + + + remote_f_names = [] + for file in all_files: + remote_f = pathlib.Path(file).name + if dataset.data.use_shuffled and remote_f.startswith("train"): + continue + + elif (not dataset.data.use_shuffled) and remote_f.startswith("shuffle"): + continue + + remote_f_names.append(remote_f) + + + assert set(dataset.data.files) == set(remote_f_names) + + aliyun_reader = AliyunOSSReader() + for fname in dataset.data.files: + p = pathlib.Path("benchmark", dataset.data.dir_name, fname) + assert aliyun_reader.bucket.object_exists(p.as_posix()) + + log.info(f"downloading to {dataset.data_dir}") + aliyun_reader.read(dataset.data.dir_name.lower(), dataset.data.files, dataset.data_dir) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index f0578c218..56c3715e5 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,4 +1,4 @@ -from vectordb_bench.backend.dataset import Dataset +from vectordb_bench.backend.dataset import Dataset, get_files import logging import pytest from pydantic import ValidationError @@ -34,3 +34,34 @@ def 
test_iter_cohere(self): for i in cohere_10m: log.debug(i.head(1)) + +class TestGetFiles: + @pytest.mark.parametrize("train_count", [ + 1, + 10, + 50, + 100, + ]) + @pytest.mark.parametrize("with_gt", [True, False]) + def test_train_count(self, train_count, with_gt): + files = get_files(train_count, True, with_gt) + log.info(files) + + if with_gt: + assert len(files) - 4 == train_count + else: + assert len(files) - 1 == train_count + + @pytest.mark.parametrize("use_shuffled", [True, False]) + def test_use_shuffled(self, use_shuffled): + files = get_files(1, use_shuffled, True) + log.info(files) + + trains = [f for f in files if "train" in f] + if use_shuffled: + for t in trains: + assert "shuffle_train" in t + else: + for t in trains: + assert "shuffle" not in t + assert "train" in t diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py index b31c75358..eca190832 100644 --- a/vectordb_bench/__init__.py +++ b/vectordb_bench/__init__.py @@ -8,10 +8,12 @@ env.read_env(".env") class config: + ALIYUN_OSS_URL = "assets.zilliz.com.cn/benchmark/" + AWS_S3_URL = "assets.zilliz.com/benchmark/" + LOG_LEVEL = env.str("LOG_LEVEL", "INFO") - DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", "assets.zilliz.com/benchmark/") - DEFAULT_DATASET_URL_ALIYUN = env.str("DEFAULT_DATASET_URL", "assets.zilliz.com.cn/benchmark/") + DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", AWS_S3_URL) DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset") NUM_PER_BATCH = env.int("NUM_PER_BATCH", 5000) diff --git a/vectordb_bench/backend/assembler.py b/vectordb_bench/backend/assembler.py index 6aaec0b63..6b0e3c81d 100644 --- a/vectordb_bench/backend/assembler.py +++ b/vectordb_bench/backend/assembler.py @@ -2,6 +2,7 @@ from .task_runner import CaseRunner, RunningStatus, TaskRunner from ..models import TaskConfig from ..backend.clients import EmptyDBCaseConfig +from ..backend.data_source import DatasetSource import logging @@ -10,7 +11,7 @@ class Assembler: @classmethod - def assemble(cls, run_id , task: TaskConfig) -> CaseRunner: + def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner: c_cls = task.case_config.case_id.case_cls c = c_cls() @@ -22,14 +23,21 @@ def assemble(cls, run_id , task: TaskConfig) -> CaseRunner: config=task, ca=c, status=RunningStatus.PENDING, + dataset_source=source, ) return runner @classmethod - def assemble_all(cls, run_id: str, task_label: str, tasks: list[TaskConfig]) -> TaskRunner: + def assemble_all( + cls, + run_id: str, + task_label: str, + tasks: list[TaskConfig], + source: DatasetSource, + ) -> TaskRunner: """group by case type, db, and case dataset""" - runners = [cls.assemble(run_id, task) for task in tasks] + runners = [cls.assemble(run_id, task, source) for task in tasks] load_runners = [r for r in runners if r.ca.label == CaseLabel.Load] perf_runners = [r for r in runners if r.ca.label == CaseLabel.Performance] diff --git a/vectordb_bench/backend/data_source.py b/vectordb_bench/backend/data_source.py new file mode 100644 index 000000000..0398ec653 --- /dev/null +++ b/vectordb_bench/backend/data_source.py @@ -0,0 +1,204 @@ +import logging +import pathlib +import typing +from enum import Enum +from tqdm import tqdm +from hashlib import md5 +import os +from abc import ABC, abstractmethod + +from .. 
import config + +logging.getLogger("s3fs").setLevel(logging.CRITICAL) + +log = logging.getLogger(__name__) + +DatasetReader = typing.TypeVar("DatasetReader") + +class DatasetSource(Enum): + S3 = "S3" + AliyunOSS = "AliyunOSS" + + def reader(self) -> DatasetReader: + if self == DatasetSource.S3: + return AwsS3Reader() + + if self == DatasetSource.AliyunOSS: + return AliyunOSSReader() + + +class DatasetReader(ABC): + source: DatasetSource + remote_root: str + + @abstractmethod + def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True): + """read dataset files from remote_root to local_ds_root, + + Args: + dataset(str): for instance "sift_small_500k" + files(list[str]): all filenames of the dataset + local_ds_root(pathlib.Path): whether to write the remote data. + check_etag(bool): whether to check the etag + """ + pass + + @abstractmethod + def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool: + pass + + +class AliyunOSSReader(DatasetReader): + source: DatasetSource = DatasetSource.AliyunOSS + remote_root: str = config.ALIYUN_OSS_URL + + def __init__(self): + import oss2 + self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True) + + def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool: + info = self.bucket.get_object_meta(remote.as_posix()) + + # check size equal + remote_size, local_size = info.content_length, os.path.getsize(local) + if remote_size != local_size: + log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]") + return False + + # check etag equal + if check_etag: + return match_etag(info.etag.strip('"').lower(), local) + + + return True + + def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = False): + downloads = [] + if not local_ds_root.exists(): + log.info(f"local dataset root path not exist, creating it: {local_ds_root}") + local_ds_root.mkdir(parents=True) + downloads = [(pathlib.Path("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files] + + else: + for file in files: + remote_file = pathlib.Path("benchmark", dataset, file) + local_file = local_ds_root.joinpath(file) + + if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)): + log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list") + downloads.append((remote_file, local_file)) + + if len(downloads) == 0: + return + + log.info(f"Start to downloading files, total count: {len(downloads)}") + for remote_file, local_file in tqdm(downloads): + log.debug(f"downloading file {remote_file} to {local_ds_root}") + self.bucket.get_object_to_file(remote_file.as_posix(), local_file.as_posix()) + + log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}") + + + +class AwsS3Reader(DatasetReader): + source: DatasetSource = DatasetSource.S3 + remote_root: str = config.AWS_S3_URL + + def __init__(self): + import s3fs + self.fs = s3fs.S3FileSystem( + anon=True, + client_kwargs={'region_name': 'us-west-2'} + ) + + def ls_all(self, dataset: str): + dataset_root_dir = pathlib.Path(self.remote_root, dataset) + log.info(f"listing dataset: {dataset_root_dir}") + names = self.fs.ls(dataset_root_dir) + for n in names: + log.info(n) + return names + + + def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True): + downloads = [] + if not local_ds_root.exists(): + 
log.info(f"local dataset root path not exist, creating it: {local_ds_root}") + local_ds_root.mkdir(parents=True) + downloads = [pathlib.Path(self.remote_root, dataset, f) for f in files] + + else: + for file in files: + remote_file = pathlib.Path(self.remote_root, dataset, file) + local_file = local_ds_root.joinpath(file) + + if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)): + log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list") + downloads.append(remote_file) + + if len(downloads) == 0: + return + + log.info(f"Start to downloading files, total count: {len(downloads)}") + for s3_file in tqdm(downloads): + log.debug(f"downloading file {s3_file} to {local_ds_root}") + self.fs.download(s3_file, local_ds_root.as_posix()) + + log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}") + + + def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool: + # info() uses ls() inside, maybe we only need to ls once + info = self.fs.info(remote) + + # check size equal + remote_size, local_size = info.get("size"), os.path.getsize(local) + if remote_size != local_size: + log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]") + return False + + # check etag equal + if check_etag: + return match_etag(info.get('ETag', "").strip('"'), local) + + return True + + +def match_etag(expected_etag: str, local_file) -> bool: + """Check if local files' etag match with S3""" + def factor_of_1MB(filesize, num_parts): + x = filesize / int(num_parts) + y = x % 1048576 + return int(x + 1048576 - y) + + def calc_etag(inputfile, partsize): + md5_digests = [] + with open(inputfile, 'rb') as f: + for chunk in iter(lambda: f.read(partsize), b''): + md5_digests.append(md5(chunk).digest()) + return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests)) + + def possible_partsizes(filesize, num_parts): + return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts + + filesize = os.path.getsize(local_file) + le = "" + if '-' not in expected_etag: # no spliting uploading + with open(local_file, 'rb') as f: + le = md5(f.read()).hexdigest() + log.debug(f"calculated local etag {le}, expected etag: {expected_etag}") + return expected_etag == le + else: + num_parts = int(expected_etag.split('-')[-1]) + partsizes = [ ## Default Partsizes Map + 8388608, # aws_cli/boto3 + 15728640, # s3cmd + factor_of_1MB(filesize, num_parts) # Used by many clients to upload large files + ] + + for partsize in filter(possible_partsizes(filesize, num_parts), partsizes): + le = calc_etag(local_file, partsize) + log.debug(f"calculated local etag {le}, expected etag: {expected_etag}") + if expected_etag == le: + return True + return False diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py index 8c2b0ceef..46c31bce5 100644 --- a/vectordb_bench/backend/dataset.py +++ b/vectordb_bench/backend/dataset.py @@ -4,14 +4,11 @@ >>> Dataset.Cohere.get(100_000) """ -import os +from collections import namedtuple import logging import pathlib -from hashlib import md5 from enum import Enum -import s3fs import pandas as pd -from tqdm import tqdm from pydantic import validator, PrivateAttr import polars as pl from pyarrow.parquet import ParquetFile @@ -20,17 +17,21 @@ from .. import config from ..backend.clients import MetricType from . 
import utils +from .data_source import DatasetSource, DatasetReader log = logging.getLogger(__name__) +SizeLabel = namedtuple('SizeLabel', ['size', 'label', 'files']) + + class BaseDataset(BaseModel): name: str size: int dim: int metric_type: MetricType use_shuffled: bool - _size_label: dict = PrivateAttr() + _size_label: dict[int, SizeLabel] = PrivateAttr() @validator("size") def verify_size(cls, v): @@ -40,19 +41,51 @@ def verify_size(cls, v): @property def label(self) -> str: - return self._size_label.get(self.size) + return self._size_label.get(self.size).label @property def dir_name(self) -> str: return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower() + @property + def files(self) -> str: + return self._size_label.get(self.size).files + + +def get_files(train_count: int, use_shuffled: bool, with_gt: bool = True) -> list[str]: + prefix = "shuffle_train" if use_shuffled else "train" + middle = f"of-{train_count}" + surfix = "parquet" + + train_files = [] + if train_count > 1: + just_size = len(str(train_count)) + for i in range(train_count): + sub_file = f"{prefix}-{str(i).rjust(just_size, '0')}-{middle}.{surfix}" + train_files.append(sub_file) + else: + train_files.append(f"{prefix}.{surfix}") + + files = ['test.parquet'] + if with_gt: + files.extend([ + 'neighbors.parquet', + 'neighbors_tail_1p.parquet', + 'neighbors_head_1p.parquet', + ]) + + files.extend(train_files) + return files + class LAION(BaseDataset): name: str = "LAION" dim: int = 768 metric_type: MetricType = MetricType.L2 use_shuffled: bool = False - _size_label: dict = {100_000_000: "LARGE"} + _size_label: dict = { + 100_000_000: SizeLabel(100_000_000, "LARGE", get_files(100, False)), + } class GIST(BaseDataset): @@ -61,8 +94,8 @@ class GIST(BaseDataset): metric_type: MetricType = MetricType.L2 use_shuffled: bool = False _size_label: dict = { - 100_000: "SMALL", - 1_000_000: "MEDIUM", + 100_000: SizeLabel(100_000, "SMALL", get_files(1, False, False)), + 1_000_000: SizeLabel(1_000_000, "MEDIUM", get_files(1, False, False)), } @@ -72,9 +105,9 @@ class Cohere(BaseDataset): metric_type: MetricType = MetricType.COSINE use_shuffled: bool = config.USE_SHUFFLED_DATA _size_label: dict = { - 100_000: "SMALL", - 1_000_000: "MEDIUM", - 10_000_000: "LARGE", + 100_000: SizeLabel(100_000, "SMALL", get_files(1, config.USE_SHUFFLED_DATA)), + 1_000_000: SizeLabel(1_000_000, "MEDIUM", get_files(1, config.USE_SHUFFLED_DATA)), + 10_000_000: SizeLabel(10_000_000, "LARGE", get_files(10, config.USE_SHUFFLED_DATA)), } @@ -83,7 +116,7 @@ class Glove(BaseDataset): dim: int = 200 metric_type: MetricType = MetricType.COSINE use_shuffled: bool = False - _size_label: dict = {1_000_000: "MEDIUM"} + _size_label: dict = {1_000_000: SizeLabel(1_000_000, "MEDIUM", get_files(1, False, False))} class SIFT(BaseDataset): @@ -92,25 +125,26 @@ class SIFT(BaseDataset): metric_type: MetricType = MetricType.L2 use_shuffled: bool = False _size_label: dict = { - 500_000: "SMALL", - 5_000_000: "MEDIUM", - 50_000_000: "LARGE", + 500_000: SizeLabel(500_000, "SMALL", get_files(1, False, False)), + 5_000_000: SizeLabel(5_000_000, "MEDIUM", get_files(1, False, False)), + # 50_000_000: SizeLabel(50_000_000, "LARGE", get_files(50, False, False)), } + class OpenAI(BaseDataset): name: str = "OpenAI" dim: int = 1536 metric_type: MetricType = MetricType.COSINE use_shuffled: bool = config.USE_SHUFFLED_DATA _size_label: dict = { - 50_000: "SMALL", - 500_000: "MEDIUM", - 5_000_000: "LARGE", + 50_000: SizeLabel(50_000, "SMALL", get_files(1, config.USE_SHUFFLED_DATA)), 
+ 500_000: SizeLabel(500_000, "MEDIUM", get_files(1, config.USE_SHUFFLED_DATA)), + 5_000_000: SizeLabel(5_000_000, "LARGE", get_files(10, config.USE_SHUFFLED_DATA)), } class DatasetManager(BaseModel): - """Download dataset if not int the local directory. Provide data for cases. + """Download dataset if not in the local directory. Provide data for cases. DatasetManager is iterable, each iteration will return the next batch of data in pandas.DataFrame @@ -122,12 +156,16 @@ class DatasetManager(BaseModel): data: BaseDataset test_data: pd.DataFrame | None = None train_files : list[str] = [] + reader: DatasetReader | None = None def __eq__(self, obj): if isinstance(obj, DatasetManager): return self.data.name == obj.data.name and self.data.label == obj.data.label return False + def set_reader(self, reader: DatasetReader): + self.reader = reader + @property def data_dir(self) -> pathlib.Path: """ data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname} @@ -139,108 +177,12 @@ def data_dir(self) -> pathlib.Path: """ return pathlib.Path(config.DATASET_LOCAL_DIR, self.data.name.lower(), self.data.dir_name.lower()) - @property - def download_dir(self) -> str: - """ data s3 directory: config.DEFAULT_DATASET_URL/{dataset_dirname} - - Examples: - >>> sift_s = Dataset.SIFT.manager(500_000) - >>> sift_s.download_dir - 'assets.zilliz.com/benchmark/sift_small_500k' - """ - return f"{config.DEFAULT_DATASET_URL}{self.data.dir_name}" - def __iter__(self): return DataSetIterator(self) - def _validate_local_file(self): - if not self.data_dir.exists(): - log.info(f"local file path not exist, creating it: {self.data_dir}") - self.data_dir.mkdir(parents=True) - - fs = s3fs.S3FileSystem( - anon=True, - client_kwargs={'region_name': 'us-west-2'} - ) - dataset_info = fs.ls(self.download_dir, detail=True) - if len(dataset_info) == 0: - raise ValueError(f"No data in s3 for dataset: {self.download_dir}") - path2etag = {info['Key']: info['ETag'].split('"')[1] for info in dataset_info} - - perfix_to_filter = "train" if self.data.use_shuffled else "shuffle_train" - filtered_keys = [key for key in path2etag.keys() if key.split("/")[-1].startswith(perfix_to_filter)] - for k in filtered_keys: - path2etag.pop(k) - - # get local files ended with '.parquet' - file_names = [p.name for p in self.data_dir.glob("*.parquet")] - log.info(f"local files: {file_names}") - log.info(f"s3 files: {path2etag.keys()}") - downloads = [] - if len(file_names) == 0: - log.info("no local files, set all to downloading lists") - downloads = path2etag.keys() - else: - # if local file exists, check the etag of local file with s3, - # make sure data files aren't corrupted. 
- for name in tqdm([key.split("/")[-1] for key in path2etag.keys()]): - s3_path = f"{self.download_dir}/{name}" - local_path = self.data_dir.joinpath(name) - log.debug(f"s3 path: {s3_path}, local_path: {local_path}") - if not local_path.exists(): - log.info(f"local file not exists: {local_path}, add to downloading lists") - downloads.append(s3_path) - - elif not self.match_etag(path2etag.get(s3_path), local_path): - log.info(f"local file etag not match with s3 file: {local_path}, add to downloading lists") - downloads.append(s3_path) - - for s3_file in tqdm(downloads): - log.debug(f"downloading file {s3_file} to {self.data_dir}") - fs.download(s3_file, self.data_dir.as_posix()) - - def match_etag(self, expected_etag: str, local_file) -> bool: - """Check if local files' etag match with S3""" - def factor_of_1MB(filesize, num_parts): - x = filesize / int(num_parts) - y = x % 1048576 - return int(x + 1048576 - y) - - def calc_etag(inputfile, partsize): - md5_digests = [] - with open(inputfile, 'rb') as f: - for chunk in iter(lambda: f.read(partsize), b''): - md5_digests.append(md5(chunk).digest()) - return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests)) - - def possible_partsizes(filesize, num_parts): - return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts - - filesize = os.path.getsize(local_file) - le = "" - if '-' not in expected_etag: # no spliting uploading - with open(local_file, 'rb') as f: - le = md5(f.read()).hexdigest() - log.debug(f"calculated local etag {le}, expected etag: {expected_etag}") - return expected_etag == le - else: - num_parts = int(expected_etag.split('-')[-1]) - partsizes = [ ## Default Partsizes Map - 8388608, # aws_cli/boto3 - 15728640, # s3cmd - factor_of_1MB(filesize, num_parts) # Used by many clients to upload large files - ] - - for partsize in filter(possible_partsizes(filesize, num_parts), partsizes): - le = calc_etag(local_file, partsize) - log.debug(f"calculated local etag {le}, expected etag: {expected_etag}") - if expected_etag == le: - return True - return False - - def prepare(self, check=True) -> bool: - """Download the dataset from S3 - url = f"{config.DEFAULT_DATASET_URL}/{self.data.dir_name}" + def prepare(self, source: DatasetSource=DatasetSource.S3, check: bool=True) -> bool: + """Download the dataset from DatasetSource + url = f"{source}/{self.data.dir_name}" download files from url to self.data_dir, there'll be 4 types of files in the data_dir - train*.parquet: for training @@ -248,9 +190,20 @@ def prepare(self, check=True) -> bool: - neighbors.parquet: ground_truth of the test.parquet - neighbors_head_1p.parquet: ground_truth of the test.parquet after filtering 1% data - neighbors_99p.parquet: ground_truth of the test.parquet after filtering 99% data + + Args: + source(DatasetSource): S3 or AliyunOSS, default as S3 + check(bool): Whether to do etags check + + Returns: + bool: whether the dataset is successfully prepared + """ - if check: - self._validate_local_file() + source.reader().read( + dataset=self.data.dir_name.lower(), + files=self.data.files, + local_ds_root=self.data_dir, + ) prefix = "shuffle_train" if self.data.use_shuffled else "train" self.train_files = sorted([f.name for f in self.data_dir.glob(f'{prefix}*.parquet')]) diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 46265297a..80c5ac1df 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -17,6 +17,7 @@ from ..metric import 
Metric from .runner import MultiProcessingSearchRunner from .runner import SerialSearchRunner, SerialInsertRunner +from .data_source import DatasetSource log = logging.getLogger(__name__) @@ -44,6 +45,7 @@ class CaseRunner(BaseModel): config: TaskConfig ca: Case status: RunningStatus + dataset_source: DatasetSource db: api.VectorDB | None = None test_emb: list[list[float]] | None = None @@ -59,7 +61,7 @@ def __eq__(self, obj): return False def display(self) -> dict: - c_dict = self.ca.dict(include={'label':True, 'filters': True,'dataset':{'data': True} }) + c_dict = self.ca.dict(include={'label':True, 'filters': True,'dataset':{'data': {'name': True, 'size': True, 'dim': True, 'metric_type': True, 'label': True}} }) c_dict['db'] = self.config.db_name return c_dict @@ -82,7 +84,7 @@ def init_db(self, drop_old: bool = True) -> None: def _pre_run(self, drop_old: bool = True): try: self.init_db(drop_old) - self.ca.dataset.prepare() + self.ca.dataset.prepare(self.dataset_source) except ModuleNotFoundError as e: log.warning(f"pre run case error: please install client for db: {self.config.db}, error={e}") raise e from None diff --git a/vectordb_bench/interface.py b/vectordb_bench/interface.py index 96d6265a8..c170c67dc 100644 --- a/vectordb_bench/interface.py +++ b/vectordb_bench/interface.py @@ -23,6 +23,7 @@ from .backend.result_collector import ResultCollector from .backend.assembler import Assembler from .backend.task_runner import TaskRunner +from .backend.data_source import DatasetSource log = logging.getLogger(__name__) @@ -39,13 +40,16 @@ def __init__(self): self.running_task: TaskRunner | None = None self.latest_error: str | None = None self.drop_old: bool = True - + self.dataset_source: DatasetSource = DatasetSource.S3 + def set_drop_old(self, drop_old: bool): self.drop_old = drop_old - + def set_download_address(self, use_aliyun: bool): - # todo - pass + if use_aliyun: + self.dataset_source = DatasetSource.AliyunOSS + else: + self.dataset_source = DatasetSource.S3 def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool: """run all the tasks in the configs, write one result into the path""" @@ -58,7 +62,7 @@ def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool: log.warning("Empty tasks submitted") return False - log.debug(f"tasks: {tasks}") + log.debug(f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}") # Generate run_id run_id = uuid.uuid4().hex @@ -69,7 +73,7 @@ def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool: self.latest_error = "" try: - self.running_task = Assembler.assemble_all(run_id, task_label, tasks) + self.running_task = Assembler.assemble_all(run_id, task_label, tasks, self.dataset_source) self.running_task.display() except ModuleNotFoundError as e: msg = f"Please install client for database, error={e}" diff --git a/vectordb_bench/log_util.py b/vectordb_bench/log_util.py index bbf3f4923..b923bdcd2 100644 --- a/vectordb_bench/log_util.py +++ b/vectordb_bench/log_util.py @@ -1,7 +1,6 @@ import logging from logging import config - def init(log_level): LOGGING = { 'version': 1,
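
A minimal sketch of the intended call flow after this change, using only names the patch introduces or keeps (Dataset.COHERE, DatasetManager.prepare, DatasetSource); the 1M Cohere size is just an example, S3 stays the default source, and the download only succeeds with network access to the chosen bucket.

from vectordb_bench.backend.data_source import DatasetSource
from vectordb_bench.backend.dataset import Dataset

# Pick a dataset/size defined in dataset.py, then fetch it from Aliyun OSS
# instead of the default S3 bucket.
cohere = Dataset.COHERE.manager(1_000_000)
cohere.prepare(source=DatasetSource.AliyunOSS)

# prepare() fills train_files, so the manager can then be iterated batch by batch.
for batch in cohere:
    print(batch.head(1))
    break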
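The enum added in data_source.py doubles as a factory, so callers only ever name the source and never construct a reader themselves. A short sketch of that dispatch and of calling a reader directly; the dataset name, file list and local path are trimmed examples.

import pathlib
from vectordb_bench.backend.data_source import DatasetSource

for source in (DatasetSource.S3, DatasetSource.AliyunOSS):
    reader = source.reader()
    print(source.value, type(reader).__name__, reader.remote_root)
# S3        AwsS3Reader      assets.zilliz.com/benchmark/
# AliyunOSS AliyunOSSReader  assets.zilliz.com.cn/benchmark/

# A reader only downloads files that are missing locally or fail the
# size/ETag validation.
DatasetSource.AliyunOSS.reader().read(
    dataset="sift_small_500k",
    files=["test.parquet", "neighbors.parquet"],
    local_ds_root=pathlib.Path("/tmp/vectordb_bench/dataset/sift/sift_small_500k"),
)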
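AliyunOSSReader wraps a handful of oss2 calls; used directly they show what the anonymous access amounts to. The endpoint, bucket name and CNAME flag mirror AliyunOSSReader.__init__ above, while the object key is only an example.

import oss2

bucket = oss2.Bucket(
    oss2.AnonymousAuth(),               # public-read bucket, no credentials needed
    "assets.zilliz.com.cn/benchmark/",  # config.ALIYUN_OSS_URL
    "benchmark",
    True,                               # is_cname, as passed in AliyunOSSReader.__init__
)

key = "benchmark/sift_small_500k/test.parquet"
if bucket.object_exists(key):
    meta = bucket.get_object_meta(key)                   # size + ETag used for validation
    print(meta.content_length, meta.etag)
    bucket.get_object_to_file(key, "/tmp/test.parquet")  # download to a local file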
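match_etag() exists because S3 and OSS report a composite ETag for multipart uploads rather than a plain MD5: the MD5 of the concatenated per-part MD5 digests, suffixed with the part count. A self-contained sketch of that rule in isolation; the file path and the 8 MiB part size are examples, whereas match_etag() above probes several candidate part sizes.

from hashlib import md5

def etag_for(path: str, partsize: int = 8 * 1024 * 1024) -> str:
    """MD5 for a single-part upload, '<md5 of joined part digests>-<parts>' otherwise."""
    digests = []
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(partsize), b""):
            digests.append(md5(chunk).digest())
    if len(digests) <= 1:
        # file fits in one part: the ETag is simply the MD5 of the whole file
        return digests[0].hex() if digests else md5(b"").hexdigest()
    return md5(b"".join(digests)).hexdigest() + f"-{len(digests)}"

print(etag_for("/tmp/vectordb_bench/dataset/sift/sift_small_500k/test.parquet"))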
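Every per-size file list in dataset.py now comes from get_files(), so the naming convention lives in one place. What it produces for a sharded, shuffled dataset versus a single-file one, with the expected output shown as comments for reference.

from vectordb_bench.backend.dataset import get_files

print(get_files(train_count=10, use_shuffled=True, with_gt=True))
# ['test.parquet',
#  'neighbors.parquet',
#  'neighbors_tail_1p.parquet',
#  'neighbors_head_1p.parquet',
#  'shuffle_train-00-of-10.parquet',
#  ...
#  'shuffle_train-09-of-10.parquet']

print(get_files(train_count=1, use_shuffled=False, with_gt=False))
# ['test.parquet', 'train.parquet']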
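On the front-end side, set_download_address() is no longer a stub. A sketch of how a UI is expected to flip the source before submitting tasks; BenchMarkRunner is assumed to be the runner class defined in interface.py (the hunks above only show its methods), and the run() call is left commented because it needs a real task list.

from vectordb_bench.backend.data_source import DatasetSource
from vectordb_bench.interface import BenchMarkRunner  # assumed runner class name

runner = BenchMarkRunner()
runner.set_download_address(use_aliyun=True)   # -> DatasetSource.AliyunOSS
assert runner.dataset_source is DatasetSource.AliyunOSS

runner.set_download_address(use_aliyun=False)  # back to the default S3 source
assert runner.dataset_source is DatasetSource.S3

# runner.run(tasks) then passes dataset_source into Assembler.assemble_all(),
# which hands it to every CaseRunner for dataset.prepare().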