Skip to content

Commit

Permalink
enhance: Remove etag checks
Browse files Browse the repository at this point in the history
Etag checks cost a lot, check the size, file_name,
and, file count is enough

Signed-off-by: yangxuan <[email protected]>
  • Loading branch information
XuanYang-cn committed Mar 13, 2024
1 parent 22123d0 commit be77ccd
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 67 deletions.
4 changes: 2 additions & 2 deletions tests/test_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def per_case_test(self, type_case):
log.info(f"test case: {t.name}, {ca.name}")

filters = ca.filter_rate
ca.dataset.prepare(source=DatasetSource.AliyunOSS, check=False, filters=filters)
ca.dataset.prepare(source=DatasetSource.AliyunOSS, filters=filters)
ali_trains = ca.dataset.train_files

ca.dataset.prepare(check=False, filters=filters)
ca.dataset.prepare(filters=filters)
s3_trains = ca.dataset.train_files

assert ali_trains == s3_trains
6 changes: 2 additions & 4 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_cohere_error(self):

def test_iter_cohere(self):
cohere_10m = Dataset.COHERE.manager(10_000_000)
cohere_10m.prepare(check=False)
cohere_10m.prepare()

import time
before = time.time()
Expand All @@ -40,7 +40,7 @@ def test_iter_cohere(self):
def test_iter_laion(self):
laion_100m = Dataset.LAION.manager(100_000_000)
from vectordb_bench.backend.data_source import DatasetSource
laion_100m.prepare(source=DatasetSource.AliyunOSS, check=False)
laion_100m.prepare(source=DatasetSource.AliyunOSS)

import time
before = time.time()
Expand All @@ -66,14 +66,12 @@ def test_download_small(self):
openai_50k.data.dir_name.lower(),
files=files,
local_ds_root=openai_50k.data_dir,
check_etag=False,
)

os.remove(file_path)
DatasetSource.AliyunOSS.reader().read(
openai_50k.data.dir_name.lower(),
files=files,
local_ds_root=openai_50k.data_dir,
check_etag=False,
)

65 changes: 7 additions & 58 deletions vectordb_bench/backend/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing
from enum import Enum
from tqdm import tqdm
from hashlib import md5
import os
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -32,14 +31,13 @@ class DatasetReader(ABC):
remote_root: str

@abstractmethod
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
"""read dataset files from remote_root to local_ds_root,
Args:
dataset(str): for instance "sift_small_500k"
files(list[str]): all filenames of the dataset
local_ds_root(pathlib.Path): whether to write the remote data.
check_etag(bool): whether to check the etag
"""
pass

Expand All @@ -56,7 +54,7 @@ def __init__(self):
import oss2
self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)

def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
info = self.bucket.get_object_meta(remote.as_posix())

# check size equal
Expand All @@ -65,13 +63,9 @@ def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: b
log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
return False

# check etag equal
if check_etag:
return match_etag(info.etag.strip('"').lower(), local)

return True

def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = False):
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
downloads = []
if not local_ds_root.exists():
log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
Expand All @@ -83,8 +77,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
remote_file = pathlib.PurePosixPath("benchmark", dataset, file)
local_file = local_ds_root.joinpath(file)

# Don't check etags for Dataset from Aliyun OSS
if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, False)):
if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
downloads.append((remote_file, local_file))

Expand Down Expand Up @@ -120,7 +113,7 @@ def ls_all(self, dataset: str):
return names


def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
downloads = []
if not local_ds_root.exists():
log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
Expand All @@ -132,7 +125,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
remote_file = pathlib.PurePosixPath(self.remote_root, dataset, file)
local_file = local_ds_root.joinpath(file)

if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
downloads.append(remote_file)

Expand All @@ -147,7 +140,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")


def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
# info() uses ls() inside, maybe we only need to ls once
info = self.fs.info(remote)

Expand All @@ -157,48 +150,4 @@ def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: b
log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
return False

# check etag equal
if check_etag:
return match_etag(info.get('ETag', "").strip('"'), local)

return True


def match_etag(expected_etag: str, local_file) -> bool:
"""Check if local files' etag match with S3"""
def factor_of_1MB(filesize, num_parts):
x = filesize / int(num_parts)
y = x % 1048576
return int(x + 1048576 - y)

def calc_etag(inputfile, partsize):
md5_digests = []
with open(inputfile, 'rb') as f:
for chunk in iter(lambda: f.read(partsize), b''):
md5_digests.append(md5(chunk).digest())
return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))

def possible_partsizes(filesize, num_parts):
return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts

filesize = os.path.getsize(local_file)
le = ""
if '-' not in expected_etag: # no spliting uploading
with open(local_file, 'rb') as f:
le = md5(f.read()).hexdigest()
log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
return expected_etag == le
else:
num_parts = int(expected_etag.split('-')[-1])
partsizes = [ ## Default Partsizes Map
8388608, # aws_cli/boto3
15728640, # s3cmd
factor_of_1MB(filesize, num_parts) # Used by many clients to upload large files
]

for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
le = calc_etag(local_file, partsize)
log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
if expected_etag == le:
return True
return False
3 changes: 0 additions & 3 deletions vectordb_bench/backend/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,13 @@ def __iter__(self):
# TODO passing use_shuffle from outside
def prepare(self,
source: DatasetSource=DatasetSource.S3,
check: bool=True,
filters: int | float | str | None = None,
) -> bool:
"""Download the dataset from DatasetSource
url = f"{source}/{self.data.dir_name}"
Args:
source(DatasetSource): S3 or AliyunOSS, default as S3
check(bool): Whether to do etags check, default as ture
filters(Optional[int | float | str]): combined with dataset's with_gt to
compose the correct ground_truth file
Expand All @@ -192,7 +190,6 @@ def prepare(self,
dataset=self.data.dir_name.lower(),
files=all_files,
local_ds_root=self.data_dir,
check_etag=check,
)

if gt_file is not None and test_file is not None:
Expand Down

0 comments on commit be77ccd

Please sign in to comment.