enhance: Remove etag checks #291

Merged
4 changes: 2 additions & 2 deletions tests/test_data_source.py
@@ -19,10 +19,10 @@ def per_case_test(self, type_case):
log.info(f"test case: {t.name}, {ca.name}")

filters = ca.filter_rate
ca.dataset.prepare(source=DatasetSource.AliyunOSS, check=False, filters=filters)
ca.dataset.prepare(source=DatasetSource.AliyunOSS, filters=filters)
ali_trains = ca.dataset.train_files

ca.dataset.prepare(check=False, filters=filters)
ca.dataset.prepare(filters=filters)
s3_trains = ca.dataset.train_files

assert ali_trains == s3_trains
6 changes: 2 additions & 4 deletions tests/test_dataset.py
@@ -26,7 +26,7 @@ def test_cohere_error(self):

def test_iter_cohere(self):
cohere_10m = Dataset.COHERE.manager(10_000_000)
cohere_10m.prepare(check=False)
cohere_10m.prepare()

import time
before = time.time()
@@ -40,7 +40,7 @@ def test_iter_cohere(self):
def test_iter_laion(self):
laion_100m = Dataset.LAION.manager(100_000_000)
from vectordb_bench.backend.data_source import DatasetSource
laion_100m.prepare(source=DatasetSource.AliyunOSS, check=False)
laion_100m.prepare(source=DatasetSource.AliyunOSS)

import time
before = time.time()
@@ -66,14 +66,12 @@ def test_download_small(self):
openai_50k.data.dir_name.lower(),
files=files,
local_ds_root=openai_50k.data_dir,
check_etag=False,
)

os.remove(file_path)
DatasetSource.AliyunOSS.reader().read(
openai_50k.data.dir_name.lower(),
files=files,
local_ds_root=openai_50k.data_dir,
check_etag=False,
)
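
With check_etag gone from the reader API, test_download_small now exercises the three-argument call shown above. For reference, a minimal sketch of that call shape outside the test (the dataset directory and file names below are placeholders, not values taken from the repository):

import pathlib
from vectordb_bench.backend.data_source import DatasetSource

DatasetSource.AliyunOSS.reader().read(
    "example_dataset",                        # placeholder dataset dir name
    files=["train.parquet", "test.parquet"],  # placeholder file list
    local_ds_root=pathlib.Path("/tmp/example_dataset"),
)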

65 changes: 7 additions & 58 deletions vectordb_bench/backend/data_source.py
@@ -3,7 +3,6 @@
import typing
from enum import Enum
from tqdm import tqdm
from hashlib import md5
import os
from abc import ABC, abstractmethod

@@ -32,14 +31,13 @@ class DatasetReader(ABC):
remote_root: str

@abstractmethod
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
"""read dataset files from remote_root to local_ds_root,

Args:
dataset(str): for instance "sift_small_500k"
files(list[str]): all filenames of the dataset
local_ds_root(pathlib.Path): whether to write the remote data.
check_etag(bool): whether to check the etag
"""
pass
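
The abstract read() now takes only the dataset name, the file list, and the local root. As an illustration of the simplified contract, a hypothetical reader could look like the sketch below (the class name, remote layout, and copy logic are assumptions, and it presumes read is the only abstract member):

import pathlib
import shutil
from vectordb_bench.backend.data_source import DatasetReader

class LocalDirReader(DatasetReader):  # hypothetical example, not part of this PR
    remote_root = "/mnt/benchmark"    # assumed layout: <remote_root>/<dataset>/<file>

    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
        local_ds_root.mkdir(parents=True, exist_ok=True)
        for name in files:
            src = pathlib.Path(self.remote_root, dataset, name)
            dst = local_ds_root.joinpath(name)
            # with etag checks removed, freshness is judged by presence and size only
            if not dst.exists() or dst.stat().st_size != src.stat().st_size:
                shutil.copy(src, dst)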

@@ -56,7 +54,7 @@ def __init__(self):
import oss2
self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)

def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
info = self.bucket.get_object_meta(remote.as_posix())

# check size equal
@@ -65,13 +63,9 @@ def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: b
log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
return False

# check etag equal
if check_etag:
return match_etag(info.etag.strip('"').lower(), local)

return True

def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = False):
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
downloads = []
if not local_ds_root.exists():
log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
@@ -83,8 +77,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
remote_file = pathlib.PurePosixPath("benchmark", dataset, file)
local_file = local_ds_root.joinpath(file)

# Don't check etags for Dataset from Aliyun OSS
if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, False)):
if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
downloads.append((remote_file, local_file))

@@ -120,7 +113,7 @@ def ls_all(self, dataset: str):
return names


def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
downloads = []
if not local_ds_root.exists():
log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
@@ -132,7 +125,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
remote_file = pathlib.PurePosixPath(self.remote_root, dataset, file)
local_file = local_ds_root.joinpath(file)

if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
downloads.append(remote_file)

@@ -147,7 +140,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, chec
log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")


def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
# info() uses ls() inside, maybe we only need to ls once
info = self.fs.info(remote)

@@ -157,48 +150,4 @@ def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: b
log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
return False

# check etag equal
if check_etag:
return match_etag(info.get('ETag', "").strip('"'), local)

return True


def match_etag(expected_etag: str, local_file) -> bool:
"""Check if local files' etag match with S3"""
def factor_of_1MB(filesize, num_parts):
x = filesize / int(num_parts)
y = x % 1048576
return int(x + 1048576 - y)

def calc_etag(inputfile, partsize):
md5_digests = []
with open(inputfile, 'rb') as f:
for chunk in iter(lambda: f.read(partsize), b''):
md5_digests.append(md5(chunk).digest())
return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))

def possible_partsizes(filesize, num_parts):
return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts

filesize = os.path.getsize(local_file)
le = ""
if '-' not in expected_etag: # no spliting uploading
with open(local_file, 'rb') as f:
le = md5(f.read()).hexdigest()
log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
return expected_etag == le
else:
num_parts = int(expected_etag.split('-')[-1])
partsizes = [ ## Default Partsizes Map
8388608, # aws_cli/boto3
15728640, # s3cmd
factor_of_1MB(filesize, num_parts) # Used by many clients to upload large files
]

for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
le = calc_etag(local_file, partsize)
log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
if expected_etag == le:
return True
return False
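
For context on what is being dropped: an S3-style multipart ETag is not a plain MD5 of the file. It is the MD5 of the concatenated per-part MD5 digests, suffixed with the part count, and the part size is not recorded with the object, which is why the helper above had to try several candidate sizes. A small sketch of the same derivation the deleted calc_etag performed:

from hashlib import md5

def multipart_etag(data: bytes, part_size: int) -> str:
    # MD5 each part, then MD5 the concatenation of those digests,
    # and append the number of parts: "<hex>-<parts>"
    parts = [data[i:i + part_size] for i in range(0, len(data), part_size)]
    digests = [md5(p).digest() for p in parts]
    return md5(b"".join(digests)).hexdigest() + f"-{len(digests)}"

# e.g. a 20 MiB object uploaded with boto3's default 8 MiB part size
# produces an ETag ending in "-3"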
3 changes: 0 additions & 3 deletions vectordb_bench/backend/dataset.py
@@ -162,15 +162,13 @@ def __iter__(self):
# TODO passing use_shuffle from outside
def prepare(self,
source: DatasetSource=DatasetSource.S3,
check: bool=True,
filters: int | float | str | None = None,
) -> bool:
"""Download the dataset from DatasetSource
url = f"{source}/{self.data.dir_name}"

Args:
source(DatasetSource): S3 or AliyunOSS, default as S3
check(bool): Whether to do etags check, default as ture
filters(Optional[int | float | str]): combined with dataset's with_gt to
compose the correct ground_truth file

@@ -192,7 +190,6 @@ def prepare(self,
dataset=self.data.dir_name.lower(),
files=all_files,
local_ds_root=self.data_dir,
check_etag=check,
)

if gt_file is not None and test_file is not None:
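
With the check parameter removed here as well, callers simply drop it. A minimal usage sketch of the resulting API, reusing the manager and source names from the tests above (the import path for Dataset is assumed from the module location):

from vectordb_bench.backend.dataset import Dataset
from vectordb_bench.backend.data_source import DatasetSource

cohere_10m = Dataset.COHERE.manager(10_000_000)
cohere_10m.prepare()                                # S3 is the default source
cohere_10m.prepare(source=DatasetSource.AliyunOSS)  # no etag check from either source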