
Commit fd2b186

XuanYang-cn authored and alwayslove2013 committed
Enable aliyun OSS
Add data_source.py; vdb bench can now download datasets from Aliyun OSS. Signed-off-by: yangxuan <[email protected]>
1 parent 34e5794 commit fd2b186

10 files changed: 418 additions & 136 deletions


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ dependencies = [
     "streamlit_extras",
     "tqdm",
     "s3fs",
+    "oss2",
     "psutil",
     "polars",
     "plotly",

tests/test_data_source.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import logging
+import pathlib
+import pytest
+from vectordb_bench.backend.data_source import AliyunOSSReader, AwsS3Reader
+from vectordb_bench.backend.dataset import Dataset, DatasetManager
+
+log = logging.getLogger(__name__)
+
+class TestReader:
+    @pytest.mark.parametrize("size", [
+        100_000,
+        1_000_000,
+        10_000_000,
+    ])
+    def test_cohere(self, size):
+        cohere = Dataset.COHERE.manager(size)
+        self.per_dataset_test(cohere)
+
+    @pytest.mark.parametrize("size", [
+        100_000,
+        1_000_000,
+    ])
+    def test_gist(self, size):
+        gist = Dataset.GIST.manager(size)
+        self.per_dataset_test(gist)
+
+    @pytest.mark.parametrize("size", [
+        1_000_000,
+    ])
+    def test_glove(self, size):
+        glove = Dataset.GLOVE.manager(size)
+        self.per_dataset_test(glove)
+
+    @pytest.mark.parametrize("size", [
+        500_000,
+        5_000_000,
+        # 50_000_000,
+    ])
+    def test_sift(self, size):
+        sift = Dataset.SIFT.manager(size)
+        self.per_dataset_test(sift)
+
+    @pytest.mark.parametrize("size", [
+        50_000,
+        500_000,
+        5_000_000,
+    ])
+    def test_openai(self, size):
+        openai = Dataset.OPENAI.manager(size)
+        self.per_dataset_test(openai)
+
+
+    def per_dataset_test(self, dataset: DatasetManager):
+        s3_reader = AwsS3Reader()
+        all_files = s3_reader.ls_all(dataset.data.dir_name)
+
+
+        remote_f_names = []
+        for file in all_files:
+            remote_f = pathlib.Path(file).name
+            if dataset.data.use_shuffled and remote_f.startswith("train"):
+                continue
+
+            elif (not dataset.data.use_shuffled) and remote_f.startswith("shuffle"):
+                continue
+
+            remote_f_names.append(remote_f)
+
+
+        assert set(dataset.data.files) == set(remote_f_names)
+
+        aliyun_reader = AliyunOSSReader()
+        for fname in dataset.data.files:
+            p = pathlib.Path("benchmark", dataset.data.dir_name, fname)
+            assert aliyun_reader.bucket.object_exists(p.as_posix())
+
+        log.info(f"downloading to {dataset.data_dir}")
+        aliyun_reader.read(dataset.data.dir_name.lower(), dataset.data.files, dataset.data_dir)

tests/test_dataset.py

Lines changed: 32 additions & 1 deletion
@@ -1,4 +1,4 @@
-from vectordb_bench.backend.dataset import Dataset
+from vectordb_bench.backend.dataset import Dataset, get_files
 import logging
 import pytest
 from pydantic import ValidationError
@@ -34,3 +34,34 @@ def test_iter_cohere(self):
         for i in cohere_10m:
             log.debug(i.head(1))
 
+
+class TestGetFiles:
+    @pytest.mark.parametrize("train_count", [
+        1,
+        10,
+        50,
+        100,
+    ])
+    @pytest.mark.parametrize("with_gt", [True, False])
+    def test_train_count(self, train_count, with_gt):
+        files = get_files(train_count, True, with_gt)
+        log.info(files)
+
+        if with_gt:
+            assert len(files) - 4 == train_count
+        else:
+            assert len(files) - 1 == train_count
+
+    @pytest.mark.parametrize("use_shuffled", [True, False])
+    def test_use_shuffled(self, use_shuffled):
+        files = get_files(1, use_shuffled, True)
+        log.info(files)
+
+        trains = [f for f in files if "train" in f]
+        if use_shuffled:
+            for t in trains:
+                assert "shuffle_train" in t
+        else:
+            for t in trains:
+                assert "shuffle" not in t
+                assert "train" in t

vectordb_bench/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -8,10 +8,12 @@
 env.read_env(".env")
 
 class config:
+    ALIYUN_OSS_URL = "assets.zilliz.com.cn/benchmark/"
+    AWS_S3_URL = "assets.zilliz.com/benchmark/"
+
     LOG_LEVEL = env.str("LOG_LEVEL", "INFO")
 
-    DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", "assets.zilliz.com/benchmark/")
-    DEFAULT_DATASET_URL_ALIYUN = env.str("DEFAULT_DATASET_URL", "assets.zilliz.com.cn/benchmark/")
+    DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", AWS_S3_URL)
     DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset")
     NUM_PER_BATCH = env.int("NUM_PER_BATCH", 5000)
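
A minimal sketch of how these defaults resolve at runtime, assuming the environs-based loading shown above; the assertions simply restate the constants introduced in this diff:

from vectordb_bench import config

# DEFAULT_DATASET_URL falls back to the AWS S3 URL unless overridden via .env
assert config.AWS_S3_URL == "assets.zilliz.com/benchmark/"
assert config.ALIYUN_OSS_URL == "assets.zilliz.com.cn/benchmark/"
print(config.DEFAULT_DATASET_URL)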

vectordb_bench/backend/assembler.py

Lines changed: 11 additions & 3 deletions
@@ -2,6 +2,7 @@
 from .task_runner import CaseRunner, RunningStatus, TaskRunner
 from ..models import TaskConfig
 from ..backend.clients import EmptyDBCaseConfig
+from ..backend.data_source import DatasetSource
 import logging
 
 
@@ -10,7 +11,7 @@
 
 class Assembler:
     @classmethod
-    def assemble(cls, run_id , task: TaskConfig) -> CaseRunner:
+    def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
         c_cls = task.case_config.case_id.case_cls
 
         c = c_cls()
@@ -22,14 +23,21 @@ def assemble(cls, run_id , task: TaskConfig) -> CaseRunner:
             config=task,
             ca=c,
             status=RunningStatus.PENDING,
+            dataset_source=source,
         )
 
         return runner
 
     @classmethod
-    def assemble_all(cls, run_id: str, task_label: str, tasks: list[TaskConfig]) -> TaskRunner:
+    def assemble_all(
+        cls,
+        run_id: str,
+        task_label: str,
+        tasks: list[TaskConfig],
+        source: DatasetSource,
+    ) -> TaskRunner:
         """group by case type, db, and case dataset"""
-        runners = [cls.assemble(run_id, task) for task in tasks]
+        runners = [cls.assemble(run_id, task, source) for task in tasks]
         load_runners = [r for r in runners if r.ca.label == CaseLabel.Load]
         perf_runners = [r for r in runners if r.ca.label == CaseLabel.Performance]
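
A hedged sketch of how a caller might thread the new source argument through assemble_all; the actual call site is not shown in this commit, so the wrapper below and its names are illustrative only:

from vectordb_bench.backend.assembler import Assembler
from vectordb_bench.backend.data_source import DatasetSource
from vectordb_bench.models import TaskConfig

def build_runner(run_id: str, task_label: str, tasks: list[TaskConfig]):
    # DatasetSource.S3 keeps the previous behavior; AliyunOSS switches the download source
    return Assembler.assemble_all(run_id, task_label, tasks, source=DatasetSource.AliyunOSS)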

vectordb_bench/backend/data_source.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
+import logging
+import pathlib
+import typing
+from enum import Enum
+from tqdm import tqdm
+from hashlib import md5
+import os
+from abc import ABC, abstractmethod
+
+from .. import config
+
+logging.getLogger("s3fs").setLevel(logging.CRITICAL)
+
+log = logging.getLogger(__name__)
+
+DatasetReader = typing.TypeVar("DatasetReader")
+
+class DatasetSource(Enum):
+    S3 = "S3"
+    AliyunOSS = "AliyunOSS"
+
+    def reader(self) -> DatasetReader:
+        if self == DatasetSource.S3:
+            return AwsS3Reader()
+
+        if self == DatasetSource.AliyunOSS:
+            return AliyunOSSReader()
+
+
+class DatasetReader(ABC):
+    source: DatasetSource
+    remote_root: str
+
+    @abstractmethod
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
+        """read dataset files from remote_root to local_ds_root,
+
+        Args:
+            dataset(str): for instance "sift_small_500k"
+            files(list[str]): all filenames of the dataset
+            local_ds_root(pathlib.Path): whether to write the remote data.
+            check_etag(bool): whether to check the etag
+        """
+        pass
+
+    @abstractmethod
+    def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
+        pass
+
+
+class AliyunOSSReader(DatasetReader):
+    source: DatasetSource = DatasetSource.AliyunOSS
+    remote_root: str = config.ALIYUN_OSS_URL
+
+    def __init__(self):
+        import oss2
+        self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)
+
+    def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
+        info = self.bucket.get_object_meta(remote.as_posix())
+
+        # check size equal
+        remote_size, local_size = info.content_length, os.path.getsize(local)
+        if remote_size != local_size:
+            log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
+            return False
+
+        # check etag equal
+        if check_etag:
+            return match_etag(info.etag.strip('"').lower(), local)
+
+
+        return True
+
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = False):
+        downloads = []
+        if not local_ds_root.exists():
+            log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
+            local_ds_root.mkdir(parents=True)
+            downloads = [(pathlib.Path("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files]
+
+        else:
+            for file in files:
+                remote_file = pathlib.Path("benchmark", dataset, file)
+                local_file = local_ds_root.joinpath(file)
+
+                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
+                    log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
+                    downloads.append((remote_file, local_file))
+
+        if len(downloads) == 0:
+            return
+
+        log.info(f"Start to downloading files, total count: {len(downloads)}")
+        for remote_file, local_file in tqdm(downloads):
+            log.debug(f"downloading file {remote_file} to {local_ds_root}")
+            self.bucket.get_object_to_file(remote_file.as_posix(), local_file.as_posix())
+
+        log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")
+
+
+
+class AwsS3Reader(DatasetReader):
+    source: DatasetSource = DatasetSource.S3
+    remote_root: str = config.AWS_S3_URL
+
+    def __init__(self):
+        import s3fs
+        self.fs = s3fs.S3FileSystem(
+            anon=True,
+            client_kwargs={'region_name': 'us-west-2'}
+        )
+
+    def ls_all(self, dataset: str):
+        dataset_root_dir = pathlib.Path(self.remote_root, dataset)
+        log.info(f"listing dataset: {dataset_root_dir}")
+        names = self.fs.ls(dataset_root_dir)
+        for n in names:
+            log.info(n)
+        return names
+
+
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
+        downloads = []
+        if not local_ds_root.exists():
+            log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
+            local_ds_root.mkdir(parents=True)
+            downloads = [pathlib.Path(self.remote_root, dataset, f) for f in files]
+
+        else:
+            for file in files:
+                remote_file = pathlib.Path(self.remote_root, dataset, file)
+                local_file = local_ds_root.joinpath(file)
+
+                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
+                    log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
+                    downloads.append(remote_file)
+
+        if len(downloads) == 0:
+            return
+
+        log.info(f"Start to downloading files, total count: {len(downloads)}")
+        for s3_file in tqdm(downloads):
+            log.debug(f"downloading file {s3_file} to {local_ds_root}")
+            self.fs.download(s3_file, local_ds_root.as_posix())
+
+        log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")
+
+
+    def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
+        # info() uses ls() inside, maybe we only need to ls once
+        info = self.fs.info(remote)
+
+        # check size equal
+        remote_size, local_size = info.get("size"), os.path.getsize(local)
+        if remote_size != local_size:
+            log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
+            return False
+
+        # check etag equal
+        if check_etag:
+            return match_etag(info.get('ETag', "").strip('"'), local)
+
+        return True
+
+
+def match_etag(expected_etag: str, local_file) -> bool:
+    """Check if local files' etag match with S3"""
+    def factor_of_1MB(filesize, num_parts):
+        x = filesize / int(num_parts)
+        y = x % 1048576
+        return int(x + 1048576 - y)
+
+    def calc_etag(inputfile, partsize):
+        md5_digests = []
+        with open(inputfile, 'rb') as f:
+            for chunk in iter(lambda: f.read(partsize), b''):
+                md5_digests.append(md5(chunk).digest())
+        return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))
+
+    def possible_partsizes(filesize, num_parts):
+        return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts
+
+    filesize = os.path.getsize(local_file)
+    le = ""
+    if '-' not in expected_etag: # no spliting uploading
+        with open(local_file, 'rb') as f:
+            le = md5(f.read()).hexdigest()
+        log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
+        return expected_etag == le
+    else:
+        num_parts = int(expected_etag.split('-')[-1])
+        partsizes = [ ## Default Partsizes Map
+            8388608, # aws_cli/boto3
+            15728640, # s3cmd
+            factor_of_1MB(filesize, num_parts) # Used by many clients to upload large files
+        ]
+
+        for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
+            le = calc_etag(local_file, partsize)
+            log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
+            if expected_etag == le:
+                return True
+    return False
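
To tie the new module together, a hedged usage sketch of the reader API: the dataset name comes from the read() docstring above, while the file name and local path are hypothetical placeholders, not values from this commit:

import pathlib
from vectordb_bench.backend.data_source import DatasetSource

reader = DatasetSource.AliyunOSS.reader()   # or DatasetSource.S3.reader()
reader.read(
    dataset="sift_small_500k",              # example name from the read() docstring
    files=["test.parquet"],                 # hypothetical file name
    local_ds_root=pathlib.Path("/tmp/vectordb_bench/dataset/sift_small_500k"),
)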
