From 43302a9a6576df4393d6b0aa276de9a3d01ac17f Mon Sep 17 00:00:00 2001
From: yangxuan
Date: Wed, 13 Mar 2024 17:35:21 +0800
Subject: [PATCH] enhance: Remove etag checks

Etag checks are costly; checking the size, file name, and file count
is enough.

Signed-off-by: yangxuan
---
 tests/test_data_source.py             |  4 +-
 tests/test_dataset.py                 |  6 +--
 vectordb_bench/backend/data_source.py | 65 +++------------------------
 vectordb_bench/backend/dataset.py     |  3 --
 4 files changed, 11 insertions(+), 67 deletions(-)

diff --git a/tests/test_data_source.py b/tests/test_data_source.py
index 07e5b1878..302ee3a1c 100644
--- a/tests/test_data_source.py
+++ b/tests/test_data_source.py
@@ -19,10 +19,10 @@ def per_case_test(self, type_case):
 
         log.info(f"test case: {t.name}, {ca.name}")
         filters = ca.filter_rate
-        ca.dataset.prepare(source=DatasetSource.AliyunOSS, check=False, filters=filters)
+        ca.dataset.prepare(source=DatasetSource.AliyunOSS, filters=filters)
         ali_trains = ca.dataset.train_files
 
-        ca.dataset.prepare(check=False, filters=filters)
+        ca.dataset.prepare(filters=filters)
         s3_trains = ca.dataset.train_files
 
         assert ali_trains == s3_trains
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index c7678c206..d4ccb283d 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -26,7 +26,7 @@ def test_cohere_error(self):
 
     def test_iter_cohere(self):
         cohere_10m = Dataset.COHERE.manager(10_000_000)
-        cohere_10m.prepare(check=False)
+        cohere_10m.prepare()
 
         import time
         before = time.time()
@@ -40,7 +40,7 @@ def test_iter_cohere(self):
     def test_iter_laion(self):
         laion_100m = Dataset.LAION.manager(100_000_000)
         from vectordb_bench.backend.data_source import DatasetSource
-        laion_100m.prepare(source=DatasetSource.AliyunOSS, check=False)
+        laion_100m.prepare(source=DatasetSource.AliyunOSS)
 
         import time
         before = time.time()
@@ -66,7 +66,6 @@ def test_download_small(self):
             openai_50k.data.dir_name.lower(),
             files=files,
             local_ds_root=openai_50k.data_dir,
-            check_etag=False,
         )
 
         os.remove(file_path)
@@ -74,6 +73,5 @@ def test_download_small(self):
             openai_50k.data.dir_name.lower(),
             files=files,
             local_ds_root=openai_50k.data_dir,
-            check_etag=False,
         )
 
diff --git a/vectordb_bench/backend/data_source.py b/vectordb_bench/backend/data_source.py
index 28e3c3636..9e2f172b4 100644
--- a/vectordb_bench/backend/data_source.py
+++ b/vectordb_bench/backend/data_source.py
@@ -3,7 +3,6 @@
 import typing
 from enum import Enum
 from tqdm import tqdm
-from hashlib import md5
 import os
 from abc import ABC, abstractmethod
 
@@ -32,14 +31,13 @@ class DatasetReader(ABC):
     remote_root: str
 
     @abstractmethod
-    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
         """read dataset files from remote_root to local_ds_root,
 
         Args:
             dataset(str): for instance "sift_small_500k"
             files(list[str]): all filenames of the dataset
             local_ds_root(pathlib.Path): whether to write the remote data.
-            check_etag(bool): whether to check the etag
 
         """
         pass
@@ -56,7 +54,7 @@ def __init__(self):
         import oss2
         self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)
 
-    def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
+    def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
         info = self.bucket.get_object_meta(remote.as_posix())
 
         # check size equal
@@ -65,13 +63,9 @@ def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: b
             log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
             return False
 
-        # check etag equal
-        if check_etag:
-            return match_etag(info.etag.strip('"').lower(), local)
-
         return True
 
-    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = False):
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
         downloads = []
         if not local_ds_root.exists():
             log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
@@ -83,8 +77,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
             remote_file = pathlib.PurePosixPath("benchmark", dataset, file)
             local_file = local_ds_root.joinpath(file)
 
-            # Don't check etags for Dataset from Aliyun OSS
-            if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, False)):
+            if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
                 log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
                 downloads.append((remote_file, local_file))
 
@@ -120,7 +113,7 @@ def ls_all(self, dataset: str):
 
         return names
 
-    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
         downloads = []
         if not local_ds_root.exists():
             log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
@@ -132,7 +125,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
             remote_file = pathlib.PurePosixPath(self.remote_root, dataset, file)
             local_file = local_ds_root.joinpath(file)
 
-            if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
+            if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
                 log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
                 downloads.append(remote_file)
 
@@ -147,7 +140,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
 
         log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")
 
-    def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
+    def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
         # info() uses ls() inside, maybe we only need to ls once
         info = self.fs.info(remote)
 
@@ -157,48 +150,4 @@ def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: b
             log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
             return False
 
-        # check etag equal
-        if check_etag:
-            return match_etag(info.get('ETag', "").strip('"'), local)
-
         return True
-
-
-def match_etag(expected_etag: str, local_file) -> bool:
-    """Check if local files' etag match with S3"""
-    def factor_of_1MB(filesize, num_parts):
-        x = filesize / int(num_parts)
-        y = x % 1048576
-        return int(x + 1048576 - y)
-
-    def calc_etag(inputfile, partsize):
-        md5_digests = []
-        with open(inputfile, 'rb') as f:
-            for chunk in iter(lambda: f.read(partsize), b''):
-                md5_digests.append(md5(chunk).digest())
-        return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))
-
-    def possible_partsizes(filesize, num_parts):
-        return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts
-
-    filesize = os.path.getsize(local_file)
-    le = ""
-    if '-' not in expected_etag: # no spliting uploading
-        with open(local_file, 'rb') as f:
-            le = md5(f.read()).hexdigest()
-            log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
-            return expected_etag == le
-    else:
-        num_parts = int(expected_etag.split('-')[-1])
-        partsizes = [ ## Default Partsizes Map
-            8388608, # aws_cli/boto3
-            15728640, # s3cmd
-            factor_of_1MB(filesize, num_parts) # Used by many clients to upload large files
-        ]
-
-        for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
-            le = calc_etag(local_file, partsize)
-            log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
-            if expected_etag == le:
-                return True
-        return False
diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py
index ffe5b10f0..2b630eae3 100644
--- a/vectordb_bench/backend/dataset.py
+++ b/vectordb_bench/backend/dataset.py
@@ -162,7 +162,6 @@ def __iter__(self):
     # TODO passing use_shuffle from outside
     def prepare(self,
         source: DatasetSource=DatasetSource.S3,
-        check: bool=True,
         filters: int | float | str | None = None,
     ) -> bool:
         """Download the dataset from DatasetSource
@@ -170,7 +169,6 @@ def prepare(self,
 
         Args:
            source(DatasetSource): S3 or AliyunOSS, default as S3
-            check(bool): Whether to do etags check, default as ture
            filters(Optional[int | float | str]): combined with dataset's with_gt to
                compose the correct ground_truth file
 
@@ -192,7 +190,6 @@ def prepare(self,
             dataset=self.data.dir_name.lower(),
             files=all_files,
             local_ds_root=self.data_dir,
-            check_etag=check,
         )
 
         if gt_file is not None and test_file is not None:
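
Reviewer note (not part of the patch): after this change, both readers keep only the size comparison in validate_file before deciding whether to re-download a file, using the size reported by the remote object's metadata (bucket.get_object_meta(...) on Aliyun OSS, fs.info(...) on S3). A minimal standalone sketch of that logic, with is_file_valid and remote_size as hypothetical names introduced here for illustration:

    import logging
    import pathlib

    log = logging.getLogger(__name__)


    def is_file_valid(local: pathlib.Path, remote_size: int) -> bool:
        # Hypothetical helper, not part of the patch: a local file is kept
        # only when it exists and its byte size matches the size reported
        # by the remote store; otherwise the caller re-downloads it.
        if not local.exists():
            return False
        local_size = local.stat().st_size
        if local_size != remote_size:
            log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
            return False
        return True

The expected file names and file count are still enforced by the explicit files list that prepare() passes to read(), so dropping the MD5-based etag comparison removes the per-file hashing cost without loosening what is actually verified.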