Move file_format parameter from _as_dataset to ReadConfig because this is breaking external code that overrides `_as_dataset`.

PiperOrigin-RevId: 704667761
tomvdw authored and The TensorFlow Datasets Authors committed Dec 10, 2024
1 parent 8807a05 commit 4721f7d
Showing 3 changed files with 47 additions and 44 deletions.
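
In practice, callers that previously passed `file_format` directly to `as_dataset` now route it through `tfds.ReadConfig`. A minimal before/after sketch of the call-site change (the `mnist` builder name and the `array_record` format are illustrative assumptions; the dataset is assumed to already be prepared in that format):

```python
import tensorflow_datasets as tfds

builder = tfds.builder("mnist")  # illustrative dataset; any prepared builder works

# Before this commit: file_format was a direct as_dataset argument.
# ds = builder.as_dataset(split="train", file_format="array_record")

# After this commit: specify the file format through ReadConfig instead.
ds = builder.as_dataset(
    split="train",
    read_config=tfds.ReadConfig(file_format="array_record"),
)
```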
52 changes: 22 additions & 30 deletions tensorflow_datasets/core/dataset_builder.py
@@ -937,7 +937,6 @@ def as_dataset(
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
read_config: read_config_lib.ReadConfig | None = None,
as_supervised: bool = False,
file_format: str | file_adapters.FileFormat | None = None,
):
# pylint: disable=line-too-long
"""Constructs a `tf.data.Dataset`.
@@ -1007,9 +1006,6 @@ def as_dataset(
a 2-tuple structure `(input, label)` according to
`builder.info.supervised_keys`. If `False`, the default, the returned
`tf.data.Dataset` will have a dictionary with all the features.
file_format: if the dataset is stored in multiple file formats, then this
argument can be used to specify the file format to load. If not
specified, the default file format is used.
Returns:
`tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
@@ -1043,7 +1039,6 @@ def as_dataset(
decoders=decoders,
read_config=read_config,
as_supervised=as_supervised,
file_format=file_format,
)
all_ds = tree.map_structure(build_single_dataset, split)
return all_ds
@@ -1056,28 +1051,19 @@ def _build_single_dataset(
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
read_config: read_config_lib.ReadConfig,
as_supervised: bool,
file_format: str | file_adapters.FileFormat | None = None,
) -> tf.data.Dataset:
"""as_dataset for a single split."""
wants_full_dataset = batch_size == -1
if wants_full_dataset:
batch_size = self.info.splits.total_num_examples or sys.maxsize

if file_format is not None:
file_format = file_adapters.FileFormat.from_value(file_format)

# Build base dataset
as_dataset_kwargs = {
"split": split,
"shuffle_files": shuffle_files,
"decoders": decoders,
"read_config": read_config,
}
# Not all dataset builder classes support file_format, so only pass it if
# it's supported.
if "file_format" in inspect.signature(self._as_dataset).parameters:
as_dataset_kwargs["file_format"] = file_format
ds = self._as_dataset(**as_dataset_kwargs)
ds = self._as_dataset(
split=split,
shuffle_files=shuffle_files,
decoders=decoders,
read_config=read_config,
)

# Auto-cache datasets that are small enough to fit in memory.
if self._should_cache_ds(
@@ -1263,7 +1249,6 @@ def _as_dataset(
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
read_config: read_config_lib.ReadConfig | None = None,
shuffle_files: bool = False,
file_format: str | file_adapters.FileFormat | None = None,
) -> tf.data.Dataset:
"""Constructs a `tf.data.Dataset`.
@@ -1279,9 +1264,6 @@ def _as_dataset(
read_config: `tfds.ReadConfig`
shuffle_files: `bool`, whether to shuffle the input files. Optional,
defaults to `False`.
file_format: if the dataset is stored in multiple file formats, then this
argument can be used to specify the file format to load. If not
specified, the default file format is used.
Returns:
`tf.data.Dataset`
@@ -1525,14 +1507,16 @@ def _example_specs(self):
)
return self.info.features.get_serialized_info()

def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
def _as_dataset(
self,
split: splits_lib.Split,
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
read_config: read_config_lib.ReadConfig,
shuffle_files: bool,
file_format: file_adapters.FileFormat | None = None,
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
read_config: read_config_lib.ReadConfig | None = None,
shuffle_files: bool = False,
) -> tf.data.Dataset:
if read_config is None:
read_config = read_config_lib.ReadConfig()

# Partial decoding
# TODO(epot): Should be moved inside `features.decode_example`
if isinstance(decoders, decode.PartialDecoding):
@@ -1550,10 +1534,18 @@ def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
f"Features are not set for dataset {self.name} in {self.data_dir}!"
)

file_format = (
read_config.file_format
or self.info.file_format
or file_adapters.DEFAULT_FILE_FORMAT
)
if file_format is not None:
file_format = file_adapters.FileFormat.from_value(file_format)

reader = reader_lib.Reader(
self.data_dir,
example_specs=example_specs,
file_format=file_format or self.info.file_format,
file_format=file_format,
)
decode_fn = functools.partial(features.decode_example, decoders=decoders)
return reader.read(
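
The updated `_as_dataset` above resolves the format with a simple precedence: an explicit `ReadConfig.file_format` wins, then the format recorded in the dataset info, then the library default. A standalone sketch of that rule (`resolve_file_format` is a hypothetical helper name; `FileFormat.from_value` and `DEFAULT_FILE_FORMAT` are the TFDS utilities used in the diff above):

```python
from tensorflow_datasets.core import file_adapters


def resolve_file_format(read_config_format, info_format):
  """Mirrors the precedence applied in the new _as_dataset body."""
  file_format = (
      read_config_format  # 1. explicit ReadConfig.file_format
      or info_format  # 2. format recorded in dataset info
      or file_adapters.DEFAULT_FILE_FORMAT  # 3. library default
  )
  return file_adapters.FileFormat.from_value(file_format)
```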
6 changes: 3 additions & 3 deletions tensorflow_datasets/core/load.py
@@ -20,7 +20,6 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
import dataclasses
import difflib
import inspect
import posixpath
import re
import textwrap
@@ -671,9 +670,10 @@ def load(
as_dataset_kwargs.setdefault('batch_size', batch_size)
as_dataset_kwargs.setdefault('decoders', decoders)
as_dataset_kwargs.setdefault('shuffle_files', shuffle_files)
if file_format is not None:
read_config = read_config or read_config_lib.ReadConfig()
read_config = read_config.replace(file_format=file_format)
as_dataset_kwargs.setdefault('read_config', read_config)
if 'file_format' in inspect.signature(dbuilder.as_dataset).parameters:
as_dataset_kwargs.setdefault('file_format', file_format)

ds = dbuilder.as_dataset(**as_dataset_kwargs)
if with_info:
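
For `tfds.load`, the top-level `file_format` argument is still accepted; the change above simply merges it into the `ReadConfig` rather than inspecting the builder's `as_dataset` signature. A usage sketch (the dataset name is illustrative, and both calls assume the data was prepared in the requested format):

```python
import tensorflow_datasets as tfds

# The convenience argument still works; load() now folds it into ReadConfig.
ds = tfds.load("mnist", split="train", file_format="array_record")

# Equivalent explicit form:
ds = tfds.load(
    "mnist",
    split="train",
    read_config=tfds.ReadConfig(file_format="array_record"),
)
```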
33 changes: 22 additions & 11 deletions tensorflow_datasets/core/utils/read_config.py
@@ -17,9 +17,11 @@

from __future__ import annotations

from collections.abc import Sequence
import dataclasses
from typing import Callable, Optional, Sequence, Union, cast
from typing import Callable, cast

from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core.utils import shard_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

@@ -91,36 +93,45 @@ class ReadConfig:
False if input files have been tampered with and they don't mind missing
records or having too many of them.
override_buffer_size: number of bytes to pass to file readers for buffering.
file_format: if the dataset is stored in multiple file formats, then this
argument can be used to specify the file format to load. If not specified,
the default file format is used.
"""
# pyformat: enable

# General tf.data.Dataset parameters
options: Optional[tf.data.Options] = None
options: tf.data.Options | None = None
try_autocache: bool = True
repeat_filenames: bool = False
add_tfds_id: bool = False
# tf.data.Dataset.shuffle parameters
shuffle_seed: Optional[int] = None
shuffle_reshuffle_each_iteration: Optional[bool] = None
shuffle_seed: int | None = None
shuffle_reshuffle_each_iteration: bool | None = None
# Interleave parameters
# Ideally, we should switch interleave values to None to dynamically set
# those values depending on the user system. However, this would make the
# generation order non-deterministic across machines.
interleave_cycle_length: Union[Optional[int], _MISSING] = MISSING
interleave_block_length: Optional[int] = 16
input_context: Optional[tf.distribute.InputContext] = None
experimental_interleave_sort_fn: Optional[InterleaveSortFn] = None
interleave_cycle_length: int | None | _MISSING = MISSING
interleave_block_length: int | None = 16
input_context: tf.distribute.InputContext | None = None
experimental_interleave_sort_fn: InterleaveSortFn | None = None
skip_prefetch: bool = False
num_parallel_calls_for_decode: Optional[int] = None
num_parallel_calls_for_decode: int | None = None
# Cast to an `int`. `__post_init__` will ensure the type invariant.
num_parallel_calls_for_interleave_files: Optional[int] = cast(int, MISSING)
num_parallel_calls_for_interleave_files: int | None = cast(int, MISSING)
enable_ordering_guard: bool = True
assert_cardinality: bool = True
override_buffer_size: Optional[int] = None
override_buffer_size: int | None = None
file_format: str | file_adapters.FileFormat | None = None

def __post_init__(self):
self.options = self.options or tf.data.Options()
if self.num_parallel_calls_for_decode is None:
self.num_parallel_calls_for_decode = tf.data.AUTOTUNE
if self.num_parallel_calls_for_interleave_files == MISSING:
self.num_parallel_calls_for_interleave_files = tf.data.AUTOTUNE
if isinstance(self.file_format, str):
self.file_format = file_adapters.FileFormat.from_value(self.file_format)

def replace(self, **kwargs) -> ReadConfig:
return dataclasses.replace(self, **kwargs)
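
Since `__post_init__` now normalizes a string `file_format` into the `FileFormat` enum, either spelling can be passed. A small sketch (choosing `array_record` is an assumption about which formats exist on disk for the dataset being read):

```python
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core.utils import read_config as read_config_lib

rc_from_str = read_config_lib.ReadConfig(file_format="array_record")
rc_from_enum = read_config_lib.ReadConfig(
    file_format=file_adapters.FileFormat.ARRAY_RECORD
)
# __post_init__ converted the string, so both hold the same enum value.
assert rc_from_str.file_format == rc_from_enum.file_format
```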
