# dataset.py
import json
import pathlib
from collections.abc import Iterator

import numpy
import pandas
import pyarrow.dataset as ds
import pyarrow.parquet
from filelock import FileLock
from google.cloud.storage import Bucket, Client, transfer_manager
from pinecone.grpc import PineconeGRPC
from pyarrow.parquet import ParquetDataset, ParquetFile

import vsb
from vsb import logger
from vsb.logging import ProgressIOWrapper

class Dataset:
"""
Represents a Dataset used as the source of documents and/or queries for
Vector Search operations.
    The set of datasets is taken from the Pinecone public datasets
(https://docs.pinecone.io/docs/using-public-datasets), which reside in a
Google Cloud Storage bucket and are downloaded on-demand on first access,
then cached on the local machine.
"""
gcs_bucket = "pinecone-datasets-dev"
@staticmethod
def split_dataframe(df: pandas.DataFrame, batch_size) -> Iterator[pandas.DataFrame]:
for i in range(0, len(df), batch_size):
batch = df.iloc[i : i + batch_size]
yield batch
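    # Illustrative example (not from the original source): with a DataFrame of
    # 10 rows and batch_size=4, split_dataframe yields three DataFrames of
    # 4, 4 and 2 rows respectively.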
@staticmethod
def recall(actual_matches: list, expected_matches: list):
# Recall@K : how many relevant items were returned against how many
# relevant items exist in the entire dataset. Defined as:
# truePositives / (truePositives + falseNegatives)
# Handle degenerate case of zero matches.
if not actual_matches:
return 0
        # To allow us to calculate Recall when the count of actual_matches from
        # the query differs from expected_matches (e.g. when a Query is
        # executed with a top_k different to what the Dataset was built with),
        # limit the denominator to the minimum of the expected & actual.
        # (This allows us to use a Dataset with, say, 100 exact nearest
        # neighbours and still test the quality of results when querying at
        # top_k==10, as if only the 10 exact nearest neighbours had been
        # provided.)
relevant_size = min(len(actual_matches), len(expected_matches))
expected_matches = expected_matches[:relevant_size]
true_positives = len(set(expected_matches).intersection(set(actual_matches)))
recall = true_positives / relevant_size
return recall
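    # Worked example (illustrative, not from the original source): with
    # actual_matches=["a", "b", "c"] and expected_matches=["a", "c", "d", "e"],
    # relevant_size = min(3, 4) = 3, expected_matches is truncated to
    # ["a", "c", "d"], its intersection with actual_matches is {"a", "c"},
    # and recall = 2 / 3.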
def __init__(self, name: str = "", cache_dir: str = "", limit: int = 0):
self.name = name
self.cache = pathlib.Path(cache_dir)
self.limit = limit
self.queries = pandas.DataFrame()
@staticmethod
def list():
"""
List all available datasets on the GCS bucket.
:return: A list of dict objects, one for each dataset.
"""
client = Client.create_anonymous_client()
bucket: Bucket = client.bucket(Dataset.gcs_bucket)
metadata_blobs = bucket.list_blobs(match_glob="*/metadata.json")
datasets = []
for m in metadata_blobs:
datasets.append(json.loads(m.download_as_bytes()))
return datasets
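    # Usage sketch (assumption: each dataset's metadata.json contains a 'name'
    # field, which this module does not itself guarantee):
    #   names = [d["name"] for d in Dataset.list()]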
def get_batch_iterator(
self, num_chunks: int, chunk_id: int, batch_size: int
) -> Iterator[pyarrow.RecordBatch]:
"""Split the dataset's documents into num_chunks of approximately the
same size, returning an Iterator over the Nth chunk which yields
batches of Records of at most batch_size.
:param num_chunks: Number of chunks to split the dataset into.
        :param chunk_id: Which chunk to return an iterator for.
:param batch_size: Preferred size of each batch returned.
"""
        # If we are working with a complete dataset then we partition based
        # on the set of files which make up the dataset, returning an iterator
        # over the given chunk using pyarrow's Dataset.to_batches() API - this
        # is memory-efficient as it only has to load one file / row group into
        # memory at a time.
        # If the dataset is not complete (has been limited to N rows), then we
        # cannot use Dataset.to_batches() as it has no direct way to limit to N
        # rows. Given that specifying a limit is normally used for a
        # significantly reduced subset of the dataset (e.g. the 'test' variant),
        # and hence memory usage _should_ be low, we implement this differently
        # - read the first N rows into a table, then split / iterate that table.
assert chunk_id >= 0
assert chunk_id < num_chunks
self._download_dataset_files()
pq_files = list((self.cache / self.name).glob("passages/*.parquet"))
        # If there are fewer files than num_chunks, there's no point in
        # file-based partitioning (it would leave some users with no data).
if self.limit or len(pq_files) < num_chunks:
if len(pq_files) < num_chunks and chunk_id == 0:
                logger.warning(
                    f"Requested number of users ({num_chunks}) is greater than the "
                    f"number of parquet files ({len(pq_files)}) - will partition "
                    f"based on rows instead of files; load times will be "
                    f"significantly slower."
                )
dset = ds.dataset(pq_files)
columns = self._get_set_of_passages_columns_to_read(dset)
if self.limit:
first_n = dset.head(self.limit, columns=columns)
else:
first_n = dset.to_table(columns=columns)
self.limit = first_n.num_rows
# Calculate start / end for this chunk, then split the table
# and create an iterator over it.
quotient, remainder = divmod(self.limit, num_chunks)
chunks = [quotient + (1 if r < remainder else 0) for r in range(num_chunks)]
# Determine start position based on sum of size of all chunks prior
# to ours.
start = sum(chunks[:chunk_id])
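            # Worked example (illustrative): limit=10 and num_chunks=3 gives
            # quotient=3, remainder=1, so chunks=[4, 3, 3]; chunk_id=0 starts
            # at row 0, chunk_id=1 at row 4 and chunk_id=2 at row 7.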
user_chunk = first_n.slice(offset=start, length=chunks[chunk_id])
def table_to_batches(table) -> Iterator[pyarrow.RecordBatch]:
for batch in table.to_batches(batch_size):
yield batch
return table_to_batches(user_chunk)
else:
            # Need to split the parquet files into `num_chunks` subsets of
            # files, then return a batch iterator over the `chunk_id`th subset.
chunks = numpy.array_split(pq_files, num_chunks)
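            # Worked example (illustrative): 10 parquet files split across
            # num_chunks=4 gives sub-lists of 3, 3, 2 and 2 files.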
my_chunks = list(chunks[chunk_id])
if not my_chunks:
# No chunks for this user - nothing to do.
return []
docs_pq_dataset = ds.dataset(my_chunks)
columns = self._get_set_of_passages_columns_to_read(docs_pq_dataset)
def files_to_batches(files: list):
"""Given a list of parquet files, return an iterator over
batches of the given size across all files.
"""
for f in files:
parquet = ParquetFile(f)
for batch in parquet.iter_batches(
columns=columns, batch_size=batch_size
):
yield batch
return files_to_batches(my_chunks)
def setup_queries(self, query_limit=0):
self._download_dataset_files()
self.queries = self._load_parquet_dataset("queries", limit=query_limit)
logger.debug(
f"Using {len(self.queries)} query vectors loaded from dataset 'queries' table"
)
def _download_dataset_files(self):
with FileLock(self.cache / ".lock"):
self.cache.mkdir(parents=True, exist_ok=True)
client = Client.create_anonymous_client()
bucket: Bucket = client.bucket(Dataset.gcs_bucket)
blobs = [b for b in bucket.list_blobs(prefix=self.name + "/")]
            # Ignore directories (blobs ending in '/') as we don't explicitly
            # need them (non-empty directories will have their files
            # downloaded anyway).
blobs = [b for b in blobs if not b.name.endswith("/")]
def should_download(blob):
path = self.cache / blob.name
if not path.exists():
return True
                # File exists - check size, and assume that the same size means
                # the same file. (Ideally we would compare an MD5 hash, but the
                # local MD5 calculated via hashlib.md5() does not match the
                # remote one; possibly because the file is transmitted
                # compressed.)
local_size = path.stat().st_size
remote_size = blob.size
return local_size != remote_size
            to_download = [b for b in blobs if should_download(b)]
if not to_download:
return
logger.debug(
f"Parquet dataset: downloading {len(to_download)} files belonging to "
f"dataset '{self.name}'"
)
vsb.progress = vsb.logging.make_progressbar()
with vsb.logging.progress_task(
" Downloading dataset files",
" ✔ Dataset download complete",
total=len(to_download),
) as download_task:
for blob in to_download:
                    logger.debug(
                        f"Dataset file '{blob.name}' missing or stale in cache - will be downloaded"
                    )
dest_path = self.cache / blob.name
dest_path.parent.mkdir(parents=True, exist_ok=True)
blob.download_to_file(
ProgressIOWrapper(
dest=dest_path,
progress=vsb.progress,
total=blob.size,
scale=1024 * 1024,
indent=2,
)
)
if vsb.progress:
vsb.progress.update(download_task, advance=1)
# Clear the progress bar now we're done.
vsb.progress.stop()
vsb.progress = None
def _load_parquet_dataset(self, kind, limit=0):
        parquet_files = list((self.cache / self.name).glob(kind + "/*.parquet"))
        if not parquet_files:
return pandas.DataFrame()
dataset = ParquetDataset(parquet_files)
        # Read only the columns that the Pinecone SDK makes use of.
if kind == "documents":
columns = ["id", "values", "sparse_values", "metadata"]
metadata_column = "metadata"
elif kind == "passages":
columns = self._get_set_of_passages_columns_to_read(dataset)
metadata_column = "metadata"
elif kind == "queries":
            # The 'queries' format consists of query input parameters
            # and expected results.
# * Required fields:
# - top_k
# - values or vector: dense search vector
# - ground truth nearest neighbours (stored in 'blob' field)
# * Optional fields:
# - id: query identifier.
# - sparse_vector: sparse search vector.
# - filter: metadata filter
fields = set(dataset.schema.names)
# Validate required fields are present.
required = set(["top_k", "blob"])
missing = required.difference(fields)
if len(missing) > 0:
raise ValueError(
f"Missing required fields ({missing}) for queries from dataset '{self.name}'"
)
value_field = set(["values", "vector"]).intersection(fields)
match len(value_field):
case 0:
                    raise ValueError(
                        f"Missing required search vector field ('values' or 'vector') in queries from dataset '{self.name}'"
                    )
case 2:
raise ValueError(
f"Multiple search vector fields ('values' and 'vector') present in queries from dataset '{self.name}'"
)
case 1:
required = required | value_field
# Also load in supported optional fields.
optional = set(["id", "sparse_vector", "filter"])
columns = list(required.union((fields.intersection(optional))))
metadata_column = "filter"
else:
            raise ValueError(
                f"Unsupported kind '{kind}' - must be one of (documents, passages, queries)"
            )
        # Note: We need to specify pandas.ArrowDtype as the types mapper to use
        # pyarrow datatypes in the resulting DataFrame. This is significant as
        # (for reasons unknown) it allows subsequent samples() of the DataFrame
        # to be "disconnected" from the original underlying pyarrow data, and
        # hence significantly reduces memory usage when we later prune away the
        # underlying pyarrow data (see prune_documents).
df = dataset.read(columns=columns).to_pandas(types_mapper=pandas.ArrowDtype)
if limit:
df = df.iloc[:limit]
        # Drop any columns for which all values are missing - e.g. not all
        # datasets have sparse_values, but the parquet file may still have
        # the (empty) column present.
df.dropna(axis="columns", how="all", inplace=True)
if metadata_column in df:
def cleanup_null_values(metadata):
                # Null metadata values are not supported; remove any key
                # with a null value.
if not metadata:
return None
return {k: v for k, v in metadata.items() if v}
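            # Illustrative example: {"colour": "red", "size": None} becomes
            # {"colour": "red"}. Note the truthiness check above also drops
            # other falsy values such as 0, "" and False.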
def convert_metadata_to_dict(metadata) -> dict:
# metadata is expected to be a dictionary of key-value pairs;
# however it may be encoded as a JSON string in which case we
# need to convert it.
if metadata is None:
return None
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, str):
return json.loads(metadata)
raise TypeError(
f"metadata must be a string or dict (found {type(metadata)})"
)
def prepare_metadata(metadata):
return cleanup_null_values(convert_metadata_to_dict(metadata))
df[metadata_column] = df[metadata_column].apply(prepare_metadata)
logger.debug(f"Loaded {len(df)} vectors of kind '{kind}'")
return df
def _get_set_of_passages_columns_to_read(self, dset: ds.Dataset):
        # The 'passages' format used by benchmarking datasets (e.g. mnist,
        # nq-769-tasb, yfcc, ...) always has 'id' and 'values' fields; it
        # may optionally have `sparse_values` and `metadata`.
# Validate required fields are present.
required = set(["id", "values"])
fields = set(dset.schema.names)
missing = required.difference(fields)
if len(missing) > 0:
raise ValueError(
f"Missing required fields ({missing}) for passages from dataset '{self.name}'"
)
# Also load in supported optional fields.
optional = set(["sparse_values", "metadata"])
return list(required.union((fields.intersection(optional))))
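

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The dataset name "mnist"
# (mentioned above as an example benchmark dataset), the cache directory and
# the row limit are illustrative assumptions; running this will download the
# dataset's files from the public GCS bucket on first access.
if __name__ == "__main__":
    dataset = Dataset(name="mnist", cache_dir="/tmp/vsb-cache", limit=1000)

    # Iterate the single chunk of documents in batches of up to 200 records.
    for batch in dataset.get_batch_iterator(num_chunks=1, chunk_id=0, batch_size=200):
        print(f"Read batch of {batch.num_rows} records")

    # Load the query set and report how many queries are available.
    dataset.setup_queries(query_limit=10)
    print(f"Loaded {len(dataset.queries)} queries")

    # Recall of a hypothetical query result against its ground-truth matches.
    print(Dataset.recall(actual_matches=["a", "c"], expected_matches=["a", "b", "c"]))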