Commit 7051dea
Add support for loading data in a specific file format
PiperOrigin-RevId: 704282560
tomvdw authored and The TensorFlow Datasets Authors committed Dec 9, 2024
1 parent 5d2c3e5 commit 7051dea
Showing 3 changed files with 95 additions and 24 deletions.
76 changes: 59 additions & 17 deletions tensorflow_datasets/core/dataset_builder.py
@@ -563,9 +563,9 @@ def get_reference(
data_dir=self.data_dir_root,
)

def get_file_spec(self, split: str) -> str:
def get_file_spec(self, split: str) -> str | None:
"""Returns the file spec of the split."""
split_info: splits_lib.SplitInfo = self.info.splits[split]
split_info = self.info.splits[split]
return split_info.file_spec(self.info.file_format)

def is_prepared(self) -> bool:
@@ -815,6 +815,7 @@ def as_data_source(
*,
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
deserialize_method: decode.DeserializeMethod = decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
file_format: str | file_adapters.FileFormat | None = None,
) -> ListOrTreeOrElem[Sequence[Any]]:
"""Constructs an `ArrayRecordDataSource`.
@@ -833,6 +834,9 @@
the features. Decoding is only supported if the examples are tf
examples. Note that if the deserialize_method method is other than
PARSE_AND_DECODE, then the `decoders` argument is ignored.
file_format: if the dataset is stored in multiple file formats, then this
can be used to specify which format to use. If not provided, we will
default to the first available format.
Returns:
`Sequence` if `split`,
@@ -868,22 +872,31 @@
"Dataset info file format is not set! For random access, one of the"
f" following formats is required: {random_access_formats_msg}"
)

suitable_formats = available_formats.intersection(random_access_formats)
if suitable_formats:
if not suitable_formats:
raise NotImplementedError(unsupported_format_msg)

if file_format is not None:
file_format = file_adapters.FileFormat.from_value(file_format)
if file_format not in suitable_formats:
raise ValueError(
f"Requested file format {file_format} is not available for this"
f" dataset. Available formats: {available_formats}"
)
chosen_format = file_format
else:
chosen_format = suitable_formats.pop()
logging.info(
"Found random access formats: %s. Chose to use %s. Overriding file"
" format in the dataset info.",
", ".join([f.name for f in suitable_formats]),
chosen_format,
)
# Change the dataset info to read from a random access format.
info.set_file_format(
chosen_format, override=True, override_if_initialized=True
)
else:
raise NotImplementedError(unsupported_format_msg)

# Change the dataset info to read from a random access format.
info.set_file_format(
chosen_format, override=True, override_if_initialized=True
)

# Create a dataset for each of the given splits
def build_single_data_source(split: str) -> Sequence[Any]:
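
Illustrative usage sketch, not part of the diff: the dataset name and format below are assumptions, and the data is assumed to already be prepared in a random-access format such as array_record.

import tensorflow_datasets as tfds

builder = tfds.builder('mnist')  # illustrative dataset name
# Pick the random-access copy explicitly instead of letting TFDS choose one.
source = builder.as_data_source(split='train', file_format='array_record')
print(len(source), source[0]['label'])  # assumes a 'label' feature
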
@@ -924,6 +937,7 @@ def as_dataset(
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
read_config: read_config_lib.ReadConfig | None = None,
as_supervised: bool = False,
file_format: str | file_adapters.FileFormat | None = None,
):
# pylint: disable=line-too-long
"""Constructs a `tf.data.Dataset`.
@@ -993,6 +1007,9 @@ def as_dataset(
a 2-tuple structure `(input, label)` according to
`builder.info.supervised_keys`. If `False`, the default, the returned
`tf.data.Dataset` will have a dictionary with all the features.
file_format: if the dataset is stored in multiple file formats, then this
argument can be used to specify the file format to load. If not
specified, the default file format is used.
Returns:
`tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
@@ -1026,6 +1043,7 @@ def as_dataset(
decoders=decoders,
read_config=read_config,
as_supervised=as_supervised,
file_format=file_format,
)
all_ds = tree.map_structure(build_single_dataset, split)
return all_ds
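
A corresponding sketch for as_dataset (illustrative names; the dataset is assumed to have been prepared in the requested format):

import tensorflow_datasets as tfds

builder = tfds.builder('mnist')  # illustrative dataset name
ds = builder.as_dataset(split='train', file_format='tfrecord')
for example in ds.take(1):
  print(example['image'].shape, example['label'])  # assumes mnist-style features
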
@@ -1038,19 +1056,29 @@ def _build_single_dataset(
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
read_config: read_config_lib.ReadConfig,
as_supervised: bool,
file_format: str | file_adapters.FileFormat | None = None,
) -> tf.data.Dataset:
"""as_dataset for a single split."""
wants_full_dataset = batch_size == -1
if wants_full_dataset:
batch_size = self.info.splits.total_num_examples or sys.maxsize

if file_format is not None:
file_format = file_adapters.FileFormat.from_value(file_format)

# Build base dataset
ds = self._as_dataset(
split=split,
shuffle_files=shuffle_files,
decoders=decoders,
read_config=read_config,
)
as_dataset_kwargs = {
"split": split,
"shuffle_files": shuffle_files,
"decoders": decoders,
"read_config": read_config,
}
# Not all dataset builder classes support file_format, so only pass it if
# it's supported.
if "file_format" in inspect.signature(self._as_dataset).parameters:
as_dataset_kwargs["file_format"] = file_format
ds = self._as_dataset(**as_dataset_kwargs)
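
The signature check above keeps older DatasetBuilder subclasses working: file_format is only forwarded when the overridden _as_dataset declares it. A minimal sketch of that pattern, using a hypothetical helper that is not TFDS code:

import inspect

def call_with_supported_kwargs(fn, **kwargs):
  # Forward only the keyword arguments that `fn` declares, assuming `fn`
  # does not itself accept **kwargs.
  params = inspect.signature(fn).parameters
  return fn(**{k: v for k, v in kwargs.items() if k in params})

def legacy_as_dataset(split, shuffle_files=False):  # no file_format parameter
  return (split, shuffle_files)

print(call_with_supported_kwargs(
    legacy_as_dataset, split='train', shuffle_files=True,
    file_format='tfrecord'))
# ('train', True): the unsupported file_format kwarg was silently dropped.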

# Auto-cache small datasets which are small enough to fit in memory.
if self._should_cache_ds(
split=split, shuffle_files=shuffle_files, read_config=read_config
@@ -1235,6 +1263,7 @@ def _as_dataset(
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
read_config: read_config_lib.ReadConfig | None = None,
shuffle_files: bool = False,
file_format: str | file_adapters.FileFormat | None = None,
) -> tf.data.Dataset:
"""Constructs a `tf.data.Dataset`.
@@ -1250,6 +1279,9 @@
read_config: `tfds.ReadConfig`
shuffle_files: `bool`, whether to shuffle the input files. Optional,
defaults to `False`.
file_format: if the dataset is stored in multiple file formats, then this
argument can be used to specify the file format to load. If not
specified, the default file format is used.
Returns:
`tf.data.Dataset`
@@ -1487,6 +1519,10 @@ def __init__(

@functools.cached_property
def _example_specs(self):
if self.info.features is None:
raise ValueError(
f"Features are not set for dataset {self.name} in {self.data_dir}!"
)
return self.info.features.get_serialized_info()

def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
@@ -1495,6 +1531,7 @@ def _as_dataset(  # pytype: disable=signature-mismatch # overriding-parameter-t
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
read_config: read_config_lib.ReadConfig,
shuffle_files: bool,
file_format: file_adapters.FileFormat | None = None,
) -> tf.data.Dataset:
# Partial decoding
# TODO(epot): Should be moved inside `features.decode_example`
@@ -1508,10 +1545,15 @@ def _as_dataset(  # pytype: disable=signature-mismatch # overriding-parameter-t
example_specs = self._example_specs
decoders = decoders # pylint: disable=self-assigning-variable

if features is None:
raise ValueError(
f"Features are not set for dataset {self.name} in {self.data_dir}!"
)

reader = reader_lib.Reader(
self.data_dir,
example_specs=example_specs,
file_format=self.info.file_format,
file_format=file_format or self.info.file_format,
)
decode_fn = functools.partial(features.decode_example, decoders=decoders)
return reader.read(
20 changes: 15 additions & 5 deletions tensorflow_datasets/core/load.py
@@ -20,6 +20,7 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
import dataclasses
import difflib
import inspect
import posixpath
import re
import textwrap
@@ -226,7 +227,7 @@ def _try_load_from_files_first(
**builder_kwargs: Any,
) -> bool:
"""Returns True if files should be used rather than code."""
if set(builder_kwargs) - {'version', 'config', 'data_dir'}:
if set(builder_kwargs) - {'version', 'config', 'data_dir', 'file_format'}:
return False # Has extra kwargs, requires original code.
elif builder_kwargs.get('version') == 'experimental_latest':
return False # Requested version requires original code
@@ -485,10 +486,13 @@ def _fetch_builder(
data_dir: epath.PathLike | None,
builder_kwargs: dict[str, Any] | None,
try_gcs: bool,
file_format: str | file_adapters.FileFormat | None = None,
) -> dataset_builder.DatasetBuilder:
"""Fetches the `tfds.core.DatasetBuilder` by name."""
if builder_kwargs is None:
builder_kwargs = {}
if file_format is not None:
builder_kwargs['file_format'] = file_format
return builder(name, data_dir=data_dir, try_gcs=try_gcs, **builder_kwargs)


@@ -529,6 +533,7 @@ def load(
download_and_prepare_kwargs: dict[str, Any] | None = None,
as_dataset_kwargs: dict[str, Any] | None = None,
try_gcs: bool = False,
file_format: str | file_adapters.FileFormat | None = None,
):
# pylint: disable=line-too-long
"""Loads the named dataset into a `tf.data.Dataset`.
@@ -636,6 +641,9 @@ def load(
fully bypass GCS, please use `try_gcs=False` and
`download_and_prepare_kwargs={'download_config':
tfds.core.download.DownloadConfig(try_download_gcs=False)})`.
file_format: if the dataset is stored in multiple file formats, then this
argument can be used to specify the file format to load. If not specified,
the default file format is used.
Returns:
ds: `tf.data.Dataset`, the dataset requested, or if `split` is None, a
@@ -648,10 +656,10 @@
Split-specific information is available in `ds_info.splits`.
""" # fmt: skip
dbuilder = _fetch_builder(
name,
data_dir,
builder_kwargs,
try_gcs,
name=name,
data_dir=data_dir,
builder_kwargs=builder_kwargs,
try_gcs=try_gcs,
)
_download_and_prepare_builder(dbuilder, download, download_and_prepare_kwargs)

@@ -664,6 +672,8 @@
as_dataset_kwargs.setdefault('decoders', decoders)
as_dataset_kwargs.setdefault('shuffle_files', shuffle_files)
as_dataset_kwargs.setdefault('read_config', read_config)
if 'file_format' in inspect.signature(dbuilder.as_dataset).parameters:
as_dataset_kwargs.setdefault('file_format', file_format)

ds = dbuilder.as_dataset(**as_dataset_kwargs)
if with_info:
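
End to end, the same choice can be made through tfds.load. A hedged sketch (illustrative dataset name; the dataset must already have been prepared in the requested format):

import tensorflow_datasets as tfds

ds, info = tfds.load(
    'mnist', split='train', with_info=True, file_format='tfrecord')
print(info.file_format)
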
23 changes: 21 additions & 2 deletions tensorflow_datasets/core/read_only_builder.py
@@ -33,6 +33,7 @@
from etils import etree
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import logging as tfds_logging
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import registered
@@ -57,6 +58,7 @@ def __init__(
builder_dir: epath.PathLike,
*,
info_proto: dataset_info_pb2.DatasetInfo | None = None,
file_format: str | file_adapters.FileFormat | None = None,
):
"""Constructor.
@@ -66,6 +68,8 @@
info_proto: DatasetInfo describing the name, config, etc of the requested
dataset. Note that this overwrites dataset info that may be present in
builder_dir.
file_format: The desired file format to use for the dataset. If not
specified, the file format in the DatasetInfo is used.
Raises:
FileNotFoundError: If the builder_dir does not exist.
@@ -74,6 +78,15 @@
if not info_proto:
info_proto = dataset_info.read_proto_from_builder_dir(builder_dir)
self._info_proto = info_proto
if file_format is not None:
file_format = file_adapters.FileFormat.from_value(file_format)
available_formats = set([self._info_proto.file_format])
available_formats.update(self._info_proto.alternative_file_formats)
if file_format.file_suffix not in available_formats:
raise ValueError(
f'File format {file_format.file_suffix} does not match the file'
f' formats in the DatasetInfo: {sorted(available_formats)}.'
)

self.name = info_proto.name
self.VERSION = version_lib.Version(info_proto.version) # pylint: disable=invalid-name
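
Loosely, the effect of this check for a dataset that was only prepared as tfrecord (hypothetical path, illustrative format):

from tensorflow_datasets.core import read_only_builder

builder = read_only_builder.ReadOnlyBuilder(
    builder_dir='/data/tensorflow_datasets/mnist/3.0.0/',  # hypothetical path
    file_format='array_record')
# Expected to raise ValueError, because array_record is not among the
# formats recorded in the dataset's DatasetInfo.
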
@@ -92,6 +105,7 @@ def __init__(
data_dir=builder_dir,
config=builder_config,
version=info_proto.version,
file_format=file_format,
)
self.assert_is_not_blocked()

@@ -154,6 +168,7 @@ def _download_and_prepare(self, **kwargs):  # pylint: disable=arguments-differ

def builder_from_directory(
builder_dir: epath.PathLike,
file_format: str | file_adapters.FileFormat | None = None,
) -> dataset_builder.DatasetBuilder:
"""Loads a `tfds.core.DatasetBuilder` from the given generated dataset path.
@@ -171,11 +186,13 @@
Args:
builder_dir: Path of the directory containing the dataset to read ( e.g.
`~/tensorflow_datasets/mnist/3.0.0/`).
file_format: The desired file format to use for the dataset. If not
specified, the default file format in the DatasetInfo is used.
Returns:
builder: `tfds.core.DatasetBuilder`, builder for dataset at the given path.
"""
return ReadOnlyBuilder(builder_dir=builder_dir)
return ReadOnlyBuilder(builder_dir=builder_dir, file_format=file_format)
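
A hedged sketch of reading a specific on-disk format through this entry point (hypothetical path; assumes the dataset was prepared in array_record in addition to its default format):

import tensorflow_datasets as tfds

builder = tfds.builder_from_directory(
    '/data/tensorflow_datasets/mnist/3.0.0/',  # hypothetical builder_dir
    file_format='array_record')
source = builder.as_data_source(split='train')
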


def builder_from_directories(
@@ -308,7 +325,8 @@ def builder_from_files(
f'and that it has been generated in: {data_dirs}. If the dataset has'
' configs, you might have to specify the config name.'
)
return builder_from_directory(builder_dir)
file_format = builder_kwargs.pop('file_format', None)
return builder_from_directory(builder_dir, file_format=file_format)


def _find_builder_dir(name: str, **builder_kwargs: Any) -> epath.Path | None:
@@ -339,6 +357,7 @@ def _find_builder_dir(name: str, **builder_kwargs: Any) -> epath.Path | None:
version = str(version) if version else None
config = builder_kwargs.pop('config', None)
data_dir = builder_kwargs.pop('data_dir', None)
_ = builder_kwargs.pop('file_format', None)

# Builder cannot be found if it uses:
# * namespace
