Skip to content

Commit 28da673

Browse files
ArrayRecord Team authored and
copybara-github committed
Allow passing reader option through ArrayRecordDataSource
PiperOrigin-RevId: 770253896
1 parent f028ebe commit 28da673

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

python/array_record_data_source.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -173,11 +173,12 @@ def get_read_instruction(path: PathLikeOrFileInstruction) -> _ReadInstruction:
173173
)
174174

175175

176-
def _create_reader(filename: epath.PathLike):
176+
def _create_reader(filename: epath.PathLike, additional_reader_options: str):
177177
"""Returns an ArrayRecordReader for the given filename."""
178+
reader_options = f"readahead_buffer_size:0,{additional_reader_options}"
178179
return array_record_module.ArrayRecordReader(
179180
filename,
180-
options="readahead_buffer_size:0",
181+
options=reader_options,
181182
file_reader_buffer_size=32768,
182183
)
183184

@@ -219,6 +220,7 @@ def __init__(
219220
paths: Union[
220221
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
221222
],
223+
reader_options: dict[str, str] | None = None,
222224
):
223225
"""Creates a new ArrayRecordDataSource object.
224226
@@ -238,6 +240,8 @@ def __init__(
238240
paths/FileInstructions. When you want to read subsets or have a large
239241
number of files prefer to pass FileInstructions. This makes the
240242
initialization faster.
243+
reader_options: optional dict mapping reader option names to values; the
244+
entries are joined as comma-separated "key:value" pairs when creating a reader.
241245
"""
242246
if isinstance(paths, (str, pathlib.Path, FileInstruction)):
243247
paths = [paths]
@@ -258,6 +262,12 @@ def __init__(
258262
"Unsupported path format was used. Path format must be "
259263
"a Sequence, String, pathlib.Path or FileInstruction."
260264
)
265+
if reader_options is None:
266+
self._reader_options_string = ""
267+
else:
268+
self._reader_options_string = ",".join(
269+
[f"{k}:{v}" for k, v in reader_options.items()]
270+
)
261271
self._read_instructions = _get_read_instructions(paths)
262272
self._paths = [ri.filename for ri in self._read_instructions]
263273
# We open readers lazily when we need to read from them.
@@ -324,7 +334,7 @@ def _ensure_reader_exists(self, reader_idx: int) -> None:
324334
if self._readers[reader_idx] is not None:
325335
return
326336
filename = self._read_instructions[reader_idx].filename
327-
reader = _create_reader(filename)
337+
reader = _create_reader(filename, self._reader_options_string)
328338
_check_group_size(filename, reader)
329339
self._readers[reader_idx] = reader
330340

python/array_record_data_source_test.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,22 @@ def test_repr(self):
247247
])
248248
self.assertRegex(repr(ar), r"ArrayRecordDataSource\(hash_of_paths=[\w]+\)")
249249

250+
@flagsaver.flagsaver(grain_use_fast_array_record_reader=False)
251+
def test_additional_reader_options(self):
252+
indices_to_read = [3, 0, 5, 9, 2, 1, 4, 7, 8, 6]
253+
ar = array_record_data_source.ArrayRecordDataSource(
254+
[
255+
self.testdata_dir / "digits.array_record-00000-of-00002",
256+
self.testdata_dir / "digits.array_record-00001-of-00002",
257+
],
258+
{"index_storage_option": "in_memory"},
259+
)
260+
# We need to read the records to trigger the creation of the readers.
261+
_ = [ar[x] for x in indices_to_read]
262+
self.assertLen(ar._readers, 2)
263+
self.assertIsInstance(ar._readers[0], array_record_module.ArrayRecordReader)
264+
self.assertIsInstance(ar._readers[1], array_record_module.ArrayRecordReader)
265+
250266

251267
class RunInParallelTest(parameterized.TestCase):
252268

0 commit comments

Comments
 (0)