From 7aaba0ecc9025b15afe20930f4c9e517959d8147 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Mon, 28 Oct 2024 16:25:45 +0100
Subject: [PATCH] always set non-null writer batch size

---
 src/datasets/arrow_writer.py | 55 ++++++++++++++++++++++++++++--------
 src/datasets/io/parquet.py   | 42 ++-------------------------
 2 files changed, 46 insertions(+), 51 deletions(-)

diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index 763ad8b7a41..23fd8b94b87 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -24,10 +24,11 @@
 from fsspec.core import url_to_fs
 
 from . import config
-from .features import Features, Image, Value
+from .features import Audio, Features, Image, Value, Video
 from .features.features import (
     FeatureType,
     _ArrayXDExtensionType,
+    _visit,
     cast_to_python_objects,
     generate_from_arrow_type,
     get_nested_type,
@@ -48,6 +49,45 @@
 type_ = type  # keep python's type function
 
 
+def get_writer_batch_size(features: Optional[Features]) -> Optional[int]:
+    """
+    Get the writer_batch_size that defines the maximum row group size in the parquet files.
+    The default in `datasets` is 1,000 but we lower it to 100 for image/audio datasets and 10 for videos.
+    This allows optimizing random access to a parquet file, since accessing one row requires
+    reading its entire row group.
+
+    This can be improved to get an optimized size for querying/iterating,
+    but at least it matches the dataset viewer expectations on HF.
+
+    Args:
+        features (`datasets.Features` or `None`):
+            Dataset Features from `datasets`.
+    Returns:
+        writer_batch_size (`Optional[int]`):
+            Writer batch size to pass to a dataset builder.
+            If `None`, then it will use the `datasets` default.
+    """
+    if not features:
+        return None
+
+    batch_size = np.inf
+
+    def set_batch_size(feature: FeatureType) -> None:
+        nonlocal batch_size
+        if isinstance(feature, Image):
+            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
+        elif isinstance(feature, Audio):
+            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
+        elif isinstance(feature, Video):
+            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
+        elif isinstance(feature, Value) and feature.dtype == "binary":
+            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)
+
+    _visit(features, set_batch_size)
+
+    return None if batch_size is np.inf else batch_size
+
+
 class SchemaInferenceError(ValueError):
     pass
 
@@ -340,7 +380,9 @@ def __init__(
 
         self.fingerprint = fingerprint
         self.disable_nullable = disable_nullable
-        self.writer_batch_size = writer_batch_size
+        self.writer_batch_size = (
+            writer_batch_size or get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
+        )
         self.update_features = update_features
         self.with_metadata = with_metadata
         self.unit = unit
@@ -353,11 +395,6 @@
         self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
         self.hkey_record = []
 
-        if self.writer_batch_size is None and self._features is not None:
-            from .io.parquet import get_writer_batch_size
-
-            self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
-
     def __len__(self):
         """Return the number of writed and staged examples"""
         return self._num_examples + len(self.current_examples) + len(self.current_rows)
@@ -402,10 +439,6 @@
             schema = schema.with_metadata({})
         self._schema = schema
         self.pa_writer = self._WRITER_CLASS(self.stream, schema)
-        if self.writer_batch_size is None:
-            from .io.parquet import get_writer_batch_size
-
-            self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
 
     @property
     def schema(self):
diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py
index 289bd1adfdc..d34f5110204 100644
--- a/src/datasets/io/parquet.py
+++ b/src/datasets/io/parquet.py
@@ -2,11 +2,10 @@
 from typing import BinaryIO, Optional, Union
 
 import fsspec
-import numpy as np
 import pyarrow.parquet as pq
 
-from .. import Audio, Dataset, Features, Image, NamedSplit, Value, Video, config
-from ..features.features import FeatureType, _visit
+from .. import Dataset, Features, NamedSplit, config
+from ..arrow_writer import get_writer_batch_size
 from ..formatting import query_table
 from ..packaged_modules import _PACKAGED_DATASETS_MODULES
 from ..packaged_modules.parquet.parquet import Parquet
@@ -15,43 +14,6 @@
 from .abc import AbstractDatasetReader
 
 
-def get_writer_batch_size(features: Features) -> Optional[int]:
-    """
-    Get the writer_batch_size that defines the maximum row group size in the parquet files.
-    The default in `datasets` is 1,000 but we lower it to 100 for image datasets.
-    This allows to optimize random access to parquet file, since accessing 1 row requires
-    to read its entire row group.
-
-    This can be improved to get optimized size for querying/iterating
-    but at least it matches the dataset viewer expectations on HF.
-
-    Args:
-        ds_config_info (`datasets.info.DatasetInfo`):
-            Dataset info from `datasets`.
-    Returns:
-        writer_batch_size (`Optional[int]`):
-            Writer batch size to pass to a dataset builder.
-            If `None`, then it will use the `datasets` default.
-    """
-
-    batch_size = np.inf
-
-    def set_batch_size(feature: FeatureType) -> None:
-        nonlocal batch_size
-        if isinstance(feature, Image):
-            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
-        elif isinstance(feature, Audio):
-            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
-        elif isinstance(feature, Video):
-            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
-        elif isinstance(feature, Value) and feature.dtype == "binary":
-            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)
-
-    _visit(features, set_batch_size)
-
-    return None if batch_size is np.inf else batch_size
-
-
 class ParquetDatasetReader(AbstractDatasetReader):
     def __init__(
         self,
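
A minimal usage sketch of the behavior introduced by this patch (not part of the diff above), assuming the library's default config values (100 rows per row group for image features, 1,000 as the overall default) and an arbitrary local path:

# Illustrative sketch only, assuming the post-patch `datasets` API and default config values.
from datasets import Features, Image, Value
from datasets.arrow_writer import ArrowWriter, get_writer_batch_size

text_features = Features({"text": Value("string")})
image_features = Features({"image": Image(), "label": Value("int64")})

# No media/binary features -> None, so ArrowWriter falls back to config.DEFAULT_MAX_BATCH_SIZE (1,000 by default).
print(get_writer_batch_size(text_features))   # None
# Image feature present -> config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS (100 by default).
print(get_writer_batch_size(image_features))  # 100

# An explicitly passed writer_batch_size still takes precedence over the feature-based value.
writer = ArrowWriter(features=image_features, path="tmp.arrow", writer_batch_size=50)  # "tmp.arrow" is an arbitrary path
print(writer.writer_batch_size)  # 50
writer.close()

With this change, writer_batch_size is resolved once in ArrowWriter.__init__ (explicit value, then feature-based value, then the default) instead of being lazily filled in from datasets.io.parquet, so it is never None afterwards.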