5 changes: 4 additions & 1 deletion tensorflow_datasets/core/file_adapters.py
@@ -259,7 +259,10 @@ def beam_sink(
) -> beam.PTransform:
"""Returns a Beam sink for writing examples in the given file format."""
file_path_prefix = filename_template.sharded_filepaths_pattern(
num_shards=num_shards, use_at_notation=True
# num_shards cannot be both in the path and passed as an argument, so
# make sure it's not in the path.
num_shards=None,
use_at_notation=True,
).removesuffix('@*')
return beam.io.WriteToTFRecord(
file_path_prefix=file_path_prefix, num_shards=num_shards
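
Note: beam.io.WriteToTFRecord derives the shard filenames itself from the bare
prefix plus num_shards (with the default shard_name_template this yields names
like <prefix>-00000-of-00002), which is why the pattern above is built with
num_shards=None and the trailing '@*' is stripped. A minimal standalone sketch,
with an illustrative prefix and example data:

import apache_beam as beam

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([b'example-1', b'example-2'])
      | beam.io.WriteToTFRecord(
          # Bare prefix only: the sink appends the shard numbering, e.g.
          # /tmp/foo-train.tfrecord-00000-of-00002.
          file_path_prefix='/tmp/foo-train.tfrecord',
          num_shards=2,
      )
  )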
35 changes: 35 additions & 0 deletions tensorflow_datasets/core/utils/file_utils.py
@@ -19,6 +19,7 @@

import collections
from collections.abc import Iterable, Iterator, Sequence
import concurrent.futures # pylint: disable=unused-import
import contextlib
import dataclasses
import functools
@@ -584,3 +585,37 @@ def publish_data(
to_data_dir.mkdir(parents=True, exist_ok=True)
for filepath in from_data_dir.iterdir():
filepath.copy(dst=to_data_dir / filepath.name, overwrite=overwrite)


def bulk_rename(
old_paths: Sequence[epath.PathLike], new_paths: Sequence[epath.PathLike]
) -> None:
"""Renames a sequence of paths in bulk."""
if len(old_paths) != len(new_paths):
raise ValueError(
'old_paths and new_paths must have the same length, but got'
f' {len(old_paths)} and {len(new_paths)}'
)
for old_path, new_path in zip(old_paths, new_paths):
if old_path == new_path:
raise ValueError(
'old_paths and new_paths must not be the same, but got'
f' {old_path} and {new_path}'
)
if not old_paths:
return
def _rename(old_and_new_paths: tuple[epath.PathLike, epath.PathLike]):
old_path, new_path = old_and_new_paths
epath.Path(old_path).rename(new_path)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
executor.map(_rename, zip(old_paths, new_paths))


def bulk_delete(paths: Sequence[epath.PathLike]) -> None:
"""Deletes a sequence of paths in bulk."""
if not paths:
return
def _delete(path):
epath.Path(path).unlink()
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
executor.map(_delete, paths)
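
Note: both helpers fan the filesystem calls out over a thread pool of at most
10 workers and return immediately on empty input. A possible usage sketch (the
directory and file names below are illustrative, not from this change):

from etils import epath
from tensorflow_datasets.core.utils import file_utils

tmp_dir = epath.Path('/tmp/bulk_demo')
tmp_dir.mkdir(parents=True, exist_ok=True)
old_paths = [tmp_dir / f'shard-{i}.tmp' for i in range(3)]
for path in old_paths:
  path.touch()

# Renames run concurrently; old_paths and new_paths must have the same length
# and be pairwise different, otherwise a ValueError is raised.
new_paths = [tmp_dir / f'shard-{i}.final' for i in range(3)]
file_utils.bulk_rename(old_paths=old_paths, new_paths=new_paths)

# Deletions are parallelized the same way.
file_utils.bulk_delete(new_paths)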
38 changes: 38 additions & 0 deletions tensorflow_datasets/core/utils/file_utils_test.py
@@ -423,5 +423,43 @@ def test_publish_data(mock_fs: testing.MockFs):
assert mock_fs.read_file(to_data_dir / filename) == content


class BulkOperationTest(testing.TestCase):

def test_bulk_delete(self):
tmp_dir = epath.Path(self.tmp_dir)
file_1 = tmp_dir / 'a'
file_2 = tmp_dir / 'b'
file_1.touch()
file_2.touch()
file_utils.bulk_delete([file_1, file_2])
self.assertFalse(file_1.exists())
self.assertFalse(file_2.exists())

def test_bulk_rename(self):
tmp_dir = epath.Path(self.tmp_dir)
orig_files = [tmp_dir / f'src{i}' for i in range(10)]
dst_files = [tmp_dir / f'dst{i}' for i in range(10)]
for file in orig_files:
file.touch()
file_utils.bulk_rename(old_paths=orig_files, new_paths=dst_files)
for file in orig_files:
self.assertFalse(file.exists())
for file in dst_files:
self.assertTrue(file.exists())

def test_bulk_rename_with_different_number_of_files(self):
tmp_dir = epath.Path(self.tmp_dir)
orig_files = [tmp_dir / f'src{i}' for i in range(10)]
dst_files = [tmp_dir / f'dst{i}' for i in range(5)]
with self.assertRaises(ValueError):
file_utils.bulk_rename(old_paths=orig_files, new_paths=dst_files)

def test_bulk_rename_with_same_old_and_new_paths(self):
tmp_dir = epath.Path(self.tmp_dir)
orig_files = [tmp_dir / f'src{i}' for i in range(10)]
with self.assertRaises(ValueError):
file_utils.bulk_rename(old_paths=orig_files, new_paths=orig_files)


if __name__ == '__main__':
testing.test_main()
33 changes: 28 additions & 5 deletions tensorflow_datasets/core/writer.py
@@ -805,10 +805,6 @@ def write_from_pcollection(self, examples_pcollection):
| "Shuffle" >> beam.Reshuffle()
| "Serialize" >> beam.Map(self._serialize_example)
)
if self._num_shards is not None:
serialized_examples = serialized_examples | "Reshard" >> beam.Reshuffle(
self._num_shards
)
return serialized_examples | "Write" >> self._file_adapter.beam_sink(
filename_template=self._filename_template, num_shards=self._num_shards
)
@@ -833,4 +829,31 @@ def finalize(self) -> tuple[list[int], int]:
total_size_bytes,
)

return shard_lengths, total_size_bytes
# Empty shards may be produced by Beam. We delete them and rename the
# non-empty shards accordingly.
all_shard_paths = self._filename_template.sharded_filepaths(
len(shard_lengths)
)
non_empty_shards: list[epath.Path] = []
non_empty_shard_lengths: list[int] = []
empty_shards: list[epath.Path] = []
for length, shard_path in zip(shard_lengths, all_shard_paths):
if length > 0:
non_empty_shard_lengths.append(length)
non_empty_shards.append(shard_path)
else:
empty_shards.append(shard_path)

non_empty_shard_paths = self._filename_template.sharded_filepaths(
len(non_empty_shards)
)
if empty_shards:
old_paths: list[epath.Path] = []
new_paths: list[epath.Path] = []
for orig_path, new_path in zip(non_empty_shards, non_empty_shard_paths):
old_paths.append(orig_path)
new_paths.append(new_path)
file_utils.bulk_delete(empty_shards)
file_utils.bulk_rename(old_paths=old_paths, new_paths=new_paths)

return non_empty_shard_lengths, total_size_bytes
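
Note: the block above drops shards that received no examples and renumbers the
remaining ones so that the -of-N suffix matches the new shard count. A
plain-string walk-through of that mapping (the filename pattern is assumed for
illustration, not taken from ShardedFileTemplate):

# Shard 1 is empty: it gets deleted and the two survivors are renamed.
shard_lengths = [3, 0, 2]
all_shards = [f'foo-train.tfrecord-{i:05d}-of-00003' for i in range(3)]

non_empty_shards = [p for n, p in zip(shard_lengths, all_shards) if n > 0]
empty_shards = [p for n, p in zip(shard_lengths, all_shards) if n == 0]
new_paths = [
    f'foo-train.tfrecord-{i:05d}-of-{len(non_empty_shards):05d}'
    for i in range(len(non_empty_shards))
]

print(empty_shards)
# ['foo-train.tfrecord-00001-of-00003']  -> removed via bulk_delete
print(list(zip(non_empty_shards, new_paths)))
# [('foo-train.tfrecord-00000-of-00003', 'foo-train.tfrecord-00000-of-00002'),
#  ('foo-train.tfrecord-00002-of-00003', 'foo-train.tfrecord-00001-of-00002')]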
70 changes: 48 additions & 22 deletions tensorflow_datasets/core/writer_test.py
@@ -32,6 +32,9 @@
from tensorflow_datasets.core.utils import shard_utils


FileFormat = file_adapters.FileFormat


class GetShardSpecsTest(testing.TestCase):
# Here we don't need to test all possible reading configs, as this is tested
# by shard_utils.py.
@@ -586,35 +589,44 @@ def test_write_tfrecord_sorted_by_key_with_holes(self):
class NoShuffleBeamWriterTest(parameterized.TestCase):

@parameterized.named_parameters(
('tfrecord', file_adapters.FileFormat.TFRECORD),
('tfrecord', FileFormat.TFRECORD, 10, None),
('tfrecord_1shard', FileFormat.TFRECORD, 10, 1),
('tfrecord_2shards', FileFormat.TFRECORD, 10, 2),
('tfrecord_more_shards_than_examples', FileFormat.TFRECORD, 10, 20),
)
def test_write_beam(self, file_format: file_adapters.FileFormat):
def test_write_beam(
self,
file_format: FileFormat,
num_examples: int,
num_shards: int | None,
):

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = epath.Path(tmp_dir)
splits = ['train-b', 'train']
filename_template = naming.ShardedFileTemplate(
dataset_name='foo',
filetype_suffix=file_format.file_suffix,
data_dir=tmp_dir,
)

def get_writer(split):
filename_template = naming.ShardedFileTemplate(
dataset_name='foo',
split=split,
filetype_suffix=file_format.file_suffix,
data_dir=tmp_dir,
)
return writer_lib.NoShuffleBeamWriter(
serializer=testing.DummySerializer('dummy specs'),
filename_template=filename_template,
filename_template=filename_template.replace(split=split),
file_format=file_format,
num_shards=num_shards,
)

to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
to_write = [(i, str(i).encode('utf-8')) for i in range(num_examples)]
# Here we need to disable type check as `beam.Create` is not capable of
# inferring the type of the PCollection elements.
options = beam.options.pipeline_options.PipelineOptions(
pipeline_type_check=False
)
writers = [get_writer(split) for split in ('train-b', 'train')]
writers = {split: get_writer(split) for split in splits}

for writer in writers:
for writer in writers.values():
with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:

@beam.ptransform_fn
@@ -624,21 +636,35 @@ def _build_pcollection(pipeline, writer):

_ = pipeline | 'test' >> _build_pcollection(writer)

files = list(tmp_dir.iterdir())
self.assertGreaterEqual(len(files), 2)
for f in files:
self.assertIn(file_format.file_suffix, f.name)
for writer in writers:
# Check all writers have the correct shard lengths and total size.
for split, writer in writers.items():
shard_lengths, total_size = writer.finalize()
files = sorted([f.name for f in tmp_dir.iterdir()])

actual_num_shards = len(shard_lengths)
self.assertNotEmpty(shard_lengths)
self.assertEqual(sum(shard_lengths), 10)
self.assertGreater(total_size, 10)
if num_shards is not None:
self.assertLessEqual(actual_num_shards, num_shards)

self.assertEqual(sum(shard_lengths), num_examples)
self.assertGreater(total_size, num_examples)

# Make sure that no shard is empty.
self.assertNotIn(0, shard_lengths)

# Make sure that all shards are present.
template = filename_template.replace(split=split)
shard_paths = [
f.name
for f in template.sharded_filepaths(num_shards=actual_num_shards)
]
self.assertContainsSubset(shard_paths, files)


class CustomExampleWriter(writer_lib.ExampleWriter):

def __init__(self):
super().__init__(file_adapters.FileFormat.TFRECORD)
super().__init__(FileFormat.TFRECORD)
self.num_examples_written = 0

def write(self, path, examples) -> file_adapters.ExamplePositions | None:
@@ -650,10 +676,10 @@ class ExampleWriterTest(parameterized.TestCase):

def test_multi_output_example_writer(self):
tfrecord_writer = mock.create_autospec(writer_lib.ExampleWriter)
tfrecord_writer.file_format = file_adapters.FileFormat.TFRECORD
tfrecord_writer.file_format = FileFormat.TFRECORD

riegeli_writer = mock.create_autospec(writer_lib.ExampleWriter)
riegeli_writer.file_format = file_adapters.FileFormat.RIEGELI
riegeli_writer.file_format = FileFormat.RIEGELI

path = '/tmp/dataset-train.tfrecord-00000-of-00001'
iterator = [