Always set non-null writer batch size #7258

Merged (1 commit) on Oct 28, 2024
55 changes: 44 additions & 11 deletions src/datasets/arrow_writer.py
@@ -24,10 +24,11 @@
from fsspec.core import url_to_fs

from . import config
from .features import Features, Image, Value
from .features import Audio, Features, Image, Value, Video
from .features.features import (
    FeatureType,
    _ArrayXDExtensionType,
    _visit,
    cast_to_python_objects,
    generate_from_arrow_type,
    get_nested_type,
@@ -48,6 +49,45 @@
type_ = type # keep python's type function


def get_writer_batch_size(features: Optional[Features]) -> Optional[int]:
    """
    Get the writer_batch_size that defines the maximum row group size in the parquet files.
    The default in `datasets` is 1,000, but we lower it to 100 for image/audio datasets and to 10 for video datasets.
    This allows optimizing random access to the parquet files, since accessing one row requires
    reading its entire row group.

    This can be improved to pick an optimal size for querying/iterating,
    but at least it matches the dataset viewer expectations on HF.

    Args:
        features (`datasets.Features` or `None`):
            Dataset Features from `datasets`.
    Returns:
        writer_batch_size (`Optional[int]`):
            Writer batch size to pass to a dataset builder.
            If `None`, then it will use the `datasets` default.
    """
    if not features:
        return None

    batch_size = np.inf

    def set_batch_size(feature: FeatureType) -> None:
        nonlocal batch_size
        if isinstance(feature, Image):
            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
        elif isinstance(feature, Audio):
            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
        elif isinstance(feature, Video):
            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
        elif isinstance(feature, Value) and feature.dtype == "binary":
            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)

    _visit(features, set_batch_size)

    return None if batch_size is np.inf else batch_size


class SchemaInferenceError(ValueError):
pass

@@ -340,7 +380,9 @@ def __init__(

        self.fingerprint = fingerprint
        self.disable_nullable = disable_nullable
        self.writer_batch_size = writer_batch_size
        self.writer_batch_size = (
            writer_batch_size or get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
        )
        self.update_features = update_features
        self.with_metadata = with_metadata
        self.unit = unit
@@ -353,11 +395,6 @@ def __init__(
        self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
        self.hkey_record = []

        if self.writer_batch_size is None and self._features is not None:
            from .io.parquet import get_writer_batch_size

            self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE

    def __len__(self):
        """Return the number of written and staged examples"""
        return self._num_examples + len(self.current_examples) + len(self.current_rows)
@@ -402,10 +439,6 @@ def _build_writer(self, inferred_schema: pa.Schema):
            schema = schema.with_metadata({})
        self._schema = schema
        self.pa_writer = self._WRITER_CLASS(self.stream, schema)
        if self.writer_batch_size is None:
            from .io.parquet import get_writer_batch_size

            self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE

    @property
    def schema(self):
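With `get_writer_batch_size` now living in `src/datasets/arrow_writer.py`, the writer itself can derive a row-group-friendly batch size from the dataset features instead of relying on the Parquet IO module. A minimal usage sketch (not part of this PR), assuming the `config.PARQUET_ROW_GROUP_SIZE_FOR_*` constants keep their library defaults:

```python
from datasets import Features, Image, Value
from datasets.arrow_writer import get_writer_batch_size

# Media features cap the row group size so random access to a single row stays cheap.
image_features = Features({"image": Image(), "label": Value("int64")})
print(get_writer_batch_size(image_features))  # e.g. 100 (PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)

# Plain scalar features impose no cap: None tells the writer to use its default batch size.
text_features = Features({"text": Value("string")})
print(get_writer_batch_size(text_features))  # None
```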
42 changes: 2 additions & 40 deletions src/datasets/io/parquet.py
@@ -2,11 +2,10 @@
from typing import BinaryIO, Optional, Union

import fsspec
import numpy as np
import pyarrow.parquet as pq

from .. import Audio, Dataset, Features, Image, NamedSplit, Value, Video, config
from ..features.features import FeatureType, _visit
from .. import Dataset, Features, NamedSplit, config
from ..arrow_writer import get_writer_batch_size
from ..formatting import query_table
from ..packaged_modules import _PACKAGED_DATASETS_MODULES
from ..packaged_modules.parquet.parquet import Parquet
@@ -15,43 +14,6 @@
from .abc import AbstractDatasetReader


def get_writer_batch_size(features: Features) -> Optional[int]:
    """
    Get the writer_batch_size that defines the maximum row group size in the parquet files.
    The default in `datasets` is 1,000 but we lower it to 100 for image datasets.
    This allows to optimize random access to parquet file, since accessing 1 row requires
    to read its entire row group.

    This can be improved to get optimized size for querying/iterating
    but at least it matches the dataset viewer expectations on HF.

    Args:
        ds_config_info (`datasets.info.DatasetInfo`):
            Dataset info from `datasets`.
    Returns:
        writer_batch_size (`Optional[int]`):
            Writer batch size to pass to a dataset builder.
            If `None`, then it will use the `datasets` default.
    """

    batch_size = np.inf

    def set_batch_size(feature: FeatureType) -> None:
        nonlocal batch_size
        if isinstance(feature, Image):
            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
        elif isinstance(feature, Audio):
            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
        elif isinstance(feature, Video):
            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
        elif isinstance(feature, Value) and feature.dtype == "binary":
            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)

    _visit(features, set_batch_size)

    return None if batch_size is np.inf else batch_size


class ParquetDatasetReader(AbstractDatasetReader):
    def __init__(
        self,
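With the duplicate helper removed from `io/parquet.py`, the batch size is resolved in one place, in `ArrowWriter.__init__`. A hedged sketch of that precedence (the `resolve_writer_batch_size` function below is hypothetical and only mirrors the expression from the diff):

```python
from datasets import Features, Image, Value, config
from datasets.arrow_writer import get_writer_batch_size

def resolve_writer_batch_size(writer_batch_size, features):
    # 1) an explicit writer_batch_size from the caller wins,
    # 2) otherwise the feature-derived cap (image/audio/video/binary features),
    # 3) otherwise the library-wide default (config.DEFAULT_MAX_BATCH_SIZE, 1,000 by default).
    return writer_batch_size or get_writer_batch_size(features) or config.DEFAULT_MAX_BATCH_SIZE

image_features = Features({"image": Image(), "label": Value("int64")})
print(resolve_writer_batch_size(64, image_features))    # 64: the explicit value takes precedence
print(resolve_writer_batch_size(None, image_features))  # feature-derived cap, e.g. 100
print(resolve_writer_batch_size(None, Features({"text": Value("string")})))  # falls back to the default
```

Because the resolution now happens unconditionally in `__init__`, `self.writer_batch_size` is never `None`, which is what the PR title refers to.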