diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py index bf38042eb81..af0e8f0a139 100644 --- a/src/datasets/features/__init__.py +++ b/src/datasets/features/__init__.py @@ -5,6 +5,7 @@ "Array4D", "Array5D", "ClassLabel", + "CustomFeature", "Features", "LargeList", "Sequence", @@ -14,8 +15,21 @@ "TranslationVariableLanguages", "Video", ] + + from .audio import Audio -from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value +from .features import ( + Array2D, + Array3D, + Array4D, + Array5D, + ClassLabel, + CustomFeature, + Features, + LargeList, + Sequence, + Value, +) from .image import Image from .translation import Translation, TranslationVariableLanguages from .video import Video diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 34622cd94d9..93c694bf7ae 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -458,6 +458,31 @@ def cast_to_python_objects(obj: Any, only_1d_for_numpy=False, optimize_list_cast )[0] +class CustomFeature: + """ + Base class for feature types like Audio, Image, ClassLabel, etc that require special treatment (encoding/decoding). + """ + + requires_encoding: ClassVar[bool] = False + requires_decoding: ClassVar[bool] = False + + def encode_example(self, example): + if self.requires_encoding: + return self._encode_example(example) + return example + + def _encode_example(self, example): + raise NotImplementedError("Should be implemented by child class if `requires_encoding` is True") + + def decode_example(self, example): + if self.requires_decoding: + return self._decode_example(example) + return example + + def _decode_example(self, example): + raise NotImplementedError("Should be implemented by child class if `requires_decoding` is True") + + @dataclass class Value: """ @@ -542,7 +567,7 @@ def __call__(self): pa_type = globals()[self.__class__.__name__ + "ExtensionType"](self.shape, self.dtype) return pa_type - def encode_example(self, value): + def _encode_example(self, value): return value @@ -1091,7 +1116,7 @@ def int2str(self, values: Union[int, Iterable]) -> Union[str, Iterable]: output = [self._int2str[int(v)] for v in values] return output if return_list else output[0] - def encode_example(self, example_data): + def _encode_example(self, example_data): if self.num_classes is None: raise ValueError( "Trying to use ClassLabel feature with undefined number of class. " @@ -1180,6 +1205,8 @@ class LargeList: Child feature data type of each item within the large list. """ + requires_encoding: ClassVar[bool] = True + requires_decoding: ClassVar[bool] = True feature: Any id: Optional[str] = None # Automatically constructed @@ -1203,6 +1230,7 @@ class LargeList: Array5D, Audio, Image, + CustomFeature, Video, ] @@ -1267,19 +1295,20 @@ def get_nested_type(schema: FeatureType) -> pa.DataType: return schema() -def encode_nested_example(schema, obj, level=0): +def encode_nested_example(schema, obj, is_nested: bool = False): """Encode a nested example. This is used since some features (in particular ClassLabel) have some logic during encoding. To avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be encoded. If the first element needs to be encoded, then all the elements of the list will be encoded, otherwise they'll stay the same. """ + # Nested structures: we allow dict, list/tuples, sequences if isinstance(schema, dict): - if level == 0 and obj is None: + if not is_nested and obj is None: raise ValueError("Got None but expected a dictionary instead") return ( - {k: encode_nested_example(schema[k], obj.get(k), level=level + 1) for k in schema} + {k: encode_nested_example(schema[k], obj.get(k), is_nested=True) for k in schema} if obj is not None else None ) @@ -1295,9 +1324,10 @@ def encode_nested_example(schema, obj, level=0): for first_elmt in obj: if _check_non_null_non_empty_recursive(first_elmt, sub_schema): break - if encode_nested_example(sub_schema, first_elmt, level=level + 1) != first_elmt: - return [encode_nested_example(sub_schema, o, level=level + 1) for o in obj] + if encode_nested_example(sub_schema, first_elmt, is_nested=True) != first_elmt: + return [encode_nested_example(sub_schema, o, is_nested=True) for o in obj] return list(obj) + elif isinstance(schema, LargeList): if obj is None: return None @@ -1307,8 +1337,8 @@ def encode_nested_example(schema, obj, level=0): for first_elmt in obj: if _check_non_null_non_empty_recursive(first_elmt, sub_schema): break - if encode_nested_example(sub_schema, first_elmt, level=level + 1) != first_elmt: - return [encode_nested_example(sub_schema, o, level=level + 1) for o in obj] + if encode_nested_example(sub_schema, first_elmt, is_nested=True) != first_elmt: + return [encode_nested_example(sub_schema, o, is_nested=True) for o in obj] return list(obj) elif isinstance(schema, Sequence): if obj is None: @@ -1320,13 +1350,13 @@ def encode_nested_example(schema, obj, level=0): if isinstance(obj, (list, tuple)): # obj is a list of dict for k in schema.feature: - list_dict[k] = [encode_nested_example(schema.feature[k], o.get(k), level=level + 1) for o in obj] + list_dict[k] = [encode_nested_example(schema.feature[k], o.get(k), is_nested=True) for o in obj] return list_dict else: # obj is a single dict for k in schema.feature: list_dict[k] = ( - [encode_nested_example(schema.feature[k], o, level=level + 1) for o in obj[k]] + [encode_nested_example(schema.feature[k], o, is_nested=True) for o in obj[k]] if k in obj else None ) @@ -1342,14 +1372,18 @@ def encode_nested_example(schema, obj, level=0): # be careful when comparing tensors here if ( not isinstance(first_elmt, list) - or encode_nested_example(schema.feature, first_elmt, level=level + 1) != first_elmt + or encode_nested_example(schema.feature, first_elmt, is_nested=True) != first_elmt ): - return [encode_nested_example(schema.feature, o, level=level + 1) for o in obj] + return [encode_nested_example(schema.feature, o, is_nested=True) for o in obj] return list(obj) # Object with special encoding: # ClassLabel will convert from string to int, TranslationVariableLanguages does some checks elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)): return schema.encode_example(obj) if obj is not None else None + + # Custom features + elif isinstance(schema, CustomFeature) and schema.requires_encoding: + return schema.encode_example(obj) if obj is not None else None # Other object should be directly convertible to a native Arrow type (like Translation and Translation) return obj @@ -1403,6 +1437,10 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni # we pass the token to read and decode files from private repositories in streaming mode if obj is not None and schema.decode: return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) + # Custom features + elif isinstance(schema, CustomFeature) and schema.requires_decoding: + if obj is not None and schema.decode: + return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) return obj @@ -1432,6 +1470,9 @@ def register_feature( Register a Feature object using a name and class. This function must be used on a Feature class. """ + assert issubclass( + feature_cls, CustomFeature + ), f"Custom feature class {feature_cls.__name__} must inherit from datasets.CustomFeature" if feature_type in _FEATURE_TYPES: logger.warning( f"Overwriting feature type '{feature_type}' ({_FEATURE_TYPES[feature_type].__name__} -> {feature_cls.__name__})" @@ -1628,6 +1669,8 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False return require_decoding(feature.feature) elif isinstance(feature, Sequence): return require_decoding(feature.feature) + elif isinstance(feature, CustomFeature): + return feature.requires_decoding and (feature.decode if not ignore_decode_attribute else True) else: return hasattr(feature, "decode_example") and (feature.decode if not ignore_decode_attribute else True)