diff --git a/src/transformers/models/altclip/processing_altclip.py b/src/transformers/models/altclip/processing_altclip.py index 2814b2d7f26e89..51ea3032053c3d 100644 --- a/src/transformers/models/altclip/processing_altclip.py +++ b/src/transformers/models/altclip/processing_altclip.py @@ -16,10 +16,24 @@ Image/Text processor class for AltCLIP """ +import sys import warnings +from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class AltCLIPProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} class AltCLIPProcessor(ProcessorMixin): @@ -59,7 +73,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): super().__init__(image_processor, tokenizer) - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + audio=None, + videos=None, + **kwargs: Unpack[AltCLIPProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not @@ -68,24 +89,16 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + images (`ImageInput`, *optional*): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when @@ -97,19 +110,20 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): if text is None and images is None: raise ValueError("You have to specify either text or images. 
Both cannot be none.") - if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + output_kwargs = self._merge_kwargs( + AltCLIPProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} + if text is not None: + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + data.update(text_features) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) - - if text is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif text is not None: - return encoding - else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data.update(image_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower.py b/src/transformers/models/bridgetower/image_processing_bridgetower.py index 7272093715f882..b9d0d41377bfde 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower.py @@ -115,8 +115,8 @@ def get_resize_output_image_size( new_width = scale * new_width new_height, new_width = int(new_height + 0.5), int(new_width + 0.5) - new_height = new_height // size_divisor * size_divisor - new_width = new_width // size_divisor * size_divisor + new_height = max(1, new_height // size_divisor) * size_divisor + new_width = max(1, new_width // size_divisor) * size_divisor return new_height, new_width @@ -238,9 +238,7 @@ def resize( The channel dimension format of the input image. If not provided, it will be inferred. """ size = get_size_dict(size, default_to_square=False) - if "shortest_edge" not in size: - raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}") - shorter = size["shortest_edge"] + shorter = size["shortest_edge"] if "shortest_edge" in size else min(size["height"], size["width"]) longer = int(1333 / 800 * shorter) output_size = get_resize_output_image_size( image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format diff --git a/src/transformers/models/chameleon/image_processing_chameleon.py b/src/transformers/models/chameleon/image_processing_chameleon.py index a23fdbed028867..2b0bd0024f3be1 100644 --- a/src/transformers/models/chameleon/image_processing_chameleon.py +++ b/src/transformers/models/chameleon/image_processing_chameleon.py @@ -44,7 +44,8 @@ import PIL -def make_batched_images(images) -> List[List[ImageInput]]: +# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images +def make_batched_images(images) -> List[ImageInput]: """ Accepts images in list or nested list format, and makes a list of images for preprocessing. diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index 1480808336d14e..14d759ec6dcf87 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -16,13 +16,36 @@ Processor class for Chameleon. 
""" +import sys from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class ChameleonTextKwargs(TextKwargs, total=False): + return_for_text_completion: bool + + +class ChameleonProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: ChameleonTextKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_for_text_completion": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } class ChameleonProcessor(ProcessorMixin): @@ -57,13 +80,9 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, ima def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - images: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: int = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, - return_for_text_completion: bool = False, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + **kwargs: Unpack[ChameleonProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -73,33 +92,13 @@ def __call__( of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + images (`ImageInput`, *optional*): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. 
- return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -114,6 +113,15 @@ def __call__( text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): raise TypeError("Invalid input text. Please provide a string, or a list of strings") + if text is None and images is None: + raise ValueError("You must provide either text or images as prompt") + + output_kwargs = self._merge_kwargs( + ChameleonProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False) # Replace the image token with the expanded image token sequence prompt_strings = [] @@ -124,19 +132,10 @@ def __call__( sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode prompt_strings.append(sample) - data = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) - + features = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) if images is not None: - pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] - data["pixel_values"] = pixel_values - - return BatchFeature(data=data, tensor_type=return_tensors) + features["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + return features # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/clap/processing_clap.py b/src/transformers/models/clap/processing_clap.py index 4d1739ecf26172..4796b8d8ea11c6 100644 --- a/src/transformers/models/clap/processing_clap.py +++ b/src/transformers/models/clap/processing_clap.py @@ -16,8 +16,22 @@ Audio/Text processor class for CLAP """ -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +import sys +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class ClapProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} class ClapProcessor(ProcessorMixin): @@ -40,7 +54,14 @@ class ClapProcessor(ProcessorMixin): def __init__(self, feature_extractor, tokenizer): super().__init__(feature_extractor, tokenizer) - def __call__(self, text=None, audios=None, return_tensors=None, **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audios: Optional[AudioInput] = None, + images=None, + videos=None, + **kwargs: Unpack[ClapProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and audio(s). 
This method forwards the `text` and `kwargs` arguments to RobertaTokenizerFast's [`~RobertaTokenizerFast.__call__`] if `text` is not `None` to @@ -49,25 +70,17 @@ def __call__(self, text=None, audios=None, return_tensors=None, **kwargs): doctsring of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + audio (`AudioInput`, *optional*): The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the sample length of the audio. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when @@ -75,26 +88,24 @@ def __call__(self, text=None, audios=None, return_tensors=None, **kwargs): `None`). - **audio_features** -- Audio features to be fed to a model. Returned when `audios` is not `None`. """ - sampling_rate = kwargs.pop("sampling_rate", None) if text is None and audios is None: raise ValueError("You have to specify either text or audios. 
Both cannot be none.") - if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + output_kwargs = self._merge_kwargs( + ClapProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} + if text is not None: + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + data.update(text_features) if audios is not None: - audio_features = self.feature_extractor( - audios, sampling_rate=sampling_rate, return_tensors=return_tensors, **kwargs - ) - - if text is not None and audios is not None: - encoding.update(audio_features) - return encoding - elif text is not None: - return encoding - else: - return BatchEncoding(data=dict(**audio_features), tensor_type=return_tensors) + audio_features = self.feature_extractor(audios, **output_kwargs["audio_kwargs"]) + data.update(audio_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/clvp/processing_clvp.py b/src/transformers/models/clvp/processing_clvp.py index 4e015cea1f8475..6946a034341d54 100644 --- a/src/transformers/models/clvp/processing_clvp.py +++ b/src/transformers/models/clvp/processing_clvp.py @@ -17,7 +17,27 @@ Processor class for CLVP """ -from ...processing_utils import ProcessorMixin +import sys +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class ClvpAudioProcessorKwargs(AudioKwargs, total=False): + raw_speech: Optional[AudioInput] + + +class ClvpProcessorKwargs(ProcessingKwargs, total=False): + audio_kwargs: ClvpAudioProcessorKwargs + _defaults = {} class ClvpProcessor(ProcessorMixin): @@ -45,33 +65,67 @@ class ClvpProcessor(ProcessorMixin): def __init__(self, feature_extractor, tokenizer): super().__init__(feature_extractor, tokenizer) - def __call__(self, *args, **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio: Optional[AudioInput] = None, + images=None, + videos=None, + **kwargs: Unpack[ClvpProcessorKwargs], + ) -> BatchFeature: """ Forwards the `audio` and `sampling_rate` arguments to [`~ClvpFeatureExtractor.__call__`] and the `text` argument to [`~ClvpTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information. + + Args: + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + audio (`AudioInput`, *optional*): + The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case + of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, + and T the sample length of the audio. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. 
Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **audio_features** -- Audio features to be fed to a model. Returned when `audios` is not `None`. """ - raw_speech = kwargs.pop("raw_speech", None) - sampling_rate = kwargs.pop("sampling_rate", None) - text = kwargs.pop("text", None) + output_kwargs = self._merge_kwargs( + ClvpProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + raw_speech = output_kwargs["audio_kwargs"].pop("raw_speech", None) + + if audio is not None and raw_speech is not None: + raise ValueError("Only one of `audio` and `raw_speech` must be specified.") + if audio is None and raw_speech is not None: + audio = raw_speech - if raw_speech is None and text is None: - raise ValueError("You need to specify either an `raw_speech` or `text` input to process.") + if audio is None and text is None: + raise ValueError("You need to specify either an `audio` or `text` input to process.") - if raw_speech is not None: - inputs = self.feature_extractor(raw_speech, sampling_rate=sampling_rate, **kwargs) + data = {} + if audio is not None: + audio_features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + data.update(audio_features) if text is not None: - encodings = self.tokenizer(text, **kwargs) - - if text is None: - return inputs - elif raw_speech is None: - return encodings - else: - inputs["input_ids"] = encodings["input_ids"] - inputs["attention_mask"] = encodings["attention_mask"] - return inputs + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + if audio is not None: + data["input_ids"] = text_features["input_ids"] + data["attention_mask"] = text_features["attention_mask"] + else: + data.update(text_features) + return BatchFeature(data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) # Copied from transformers.models.whisper.processing_whisper.WhisperProcessor.batch_decode with Whisper->Clvp def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/flava/processing_flava.py b/src/transformers/models/flava/processing_flava.py index 7f439b040a8fd0..ace0434d0bd2f7 100644 --- a/src/transformers/models/flava/processing_flava.py +++ b/src/transformers/models/flava/processing_flava.py @@ -16,13 +16,42 @@ Image/Text processor class for FLAVA """ +import sys import warnings from typing import List, Optional, Union +from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class FlavaImagesKwargs(ImagesKwargs, total=False): + return_image_mask: Optional[bool] + return_codebook_pixels: Optional[bool] + + +class FlavaProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: FlavaImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "truncation": False, + "stride": 0, + "return_overflowing_tokens": False, + 
"return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + } class FlavaProcessor(ProcessorMixin): @@ -64,69 +93,51 @@ def __call__( self, images: Optional[ImageInput] = None, text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_image_mask: Optional[bool] = None, - return_codebook_pixels: Optional[bool] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, - ): + audio=None, + videos=None, + **kwargs: Unpack[FlavaProcessorKwargs], + ) -> BatchFeature: """ This method uses [`FlavaImageProcessor.__call__`] method to prepare image(s) for the model, and [`BertTokenizerFast.__call__`] to prepare text for the model. - Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`, *optional*): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ if text is None and images is None: raise ValueError("You have to specify either text or images. 
Both cannot be none.") + output_kwargs = self._merge_kwargs( + FlavaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + data = {} if text is not None: - encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) + text_features = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + data.update(text_features) if images is not None: - image_features = self.image_processor( - images, - return_image_mask=return_image_mask, - return_codebook_pixels=return_codebook_pixels, - return_tensors=return_tensors, - **kwargs, - ) - - if text is not None and images is not None: - encoding.update(image_features) - return encoding - elif text is not None: - return encoding - else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data.update(image_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 98649c644e728c..5abb1990233ac9 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -16,8 +16,23 @@ Image/Text processor class for GIT """ -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +import sys +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class GitProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} class GitProcessor(ProcessorMixin): @@ -42,7 +57,14 @@ def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + audio=None, + videos=None, + **kwargs: Unpack[GitProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode @@ -51,24 +73,16 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. 
Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + images (`ImageInput`, *optional*): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when @@ -76,29 +90,24 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ - tokenizer_kwargs, image_processor_kwargs = {}, {} - if kwargs: - tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys} - image_processor_kwargs = { - k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys - } if text is None and images is None: raise ValueError("You have to specify either text or images. 
Both cannot be none.") - if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs) + output_kwargs = self._merge_kwargs( + GitProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} + if text is not None: + text_features = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + data.update(text_features) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **image_processor_kwargs) - - if text is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif text is not None: - return encoding - else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data.update(image_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py index 131b8fe57bd665..093aab0c4d2cb6 100644 --- a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py @@ -47,18 +47,25 @@ logger = logging.get_logger(__name__) +# Copied from transformers.models.vivit.image_processing_vivit.make_batched_videos def make_batched_videos(videos) -> List[VideoInput]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos + if isinstance(videos[0][0], PIL.Image.Image) or len(videos[0][0].shape) == 3: + return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], PIL.Image.Image): + if isinstance(videos[0], PIL.Image.Image) or len(videos[0].shape) == 3: return [videos] elif len(videos[0].shape) == 4: - return [list(video) for video in videos] + return videos - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] + elif is_valid_image(videos): + if isinstance(videos, PIL.Image.Image) or len(videos.shape) == 3: + return [[videos]] + elif len(videos.shape) == 4: + return [videos] + elif len(videos.shape) == 5: + return videos raise ValueError(f"Could not make batched video from {videos}") diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py index f56f8186b07d73..6e6ee8eb865aba 100644 --- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py @@ -17,26 +17,44 @@ """ import os +import sys +import warnings from typing import List, Optional, Union -from ...image_processing_utils import BatchFeature +from ...feature_extraction_utils import BatchFeature from ...image_utils import VideoInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import ( - AddedToken, - BatchEncoding, - PaddingStrategy, - PreTokenizedInput, - TextInput, - TruncationStrategy, -) -from ...utils import TensorType, logging +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput +from ...utils import logging from ..auto 
import AutoTokenizer +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + logger = logging.get_logger(__name__) +class InstructBlipVideoProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "truncation": None, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + } + + class InstructBlipVideoProcessor(ProcessorMixin): r""" Constructs an InstructBLIPVideo processor which wraps a InstructBLIP image processor and a LLaMa/T5 tokenizer into a single @@ -71,30 +89,59 @@ def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, num_query def __call__( self, - images: VideoInput = None, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_token_type_ids: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, + images: Optional[VideoInput] = None, # Keeping this here for backwards compatibility + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + videos: Optional[VideoInput] = None, + audio=None, + **kwargs: Unpack[InstructBlipVideoProcessorKwargs], ) -> BatchFeature: """ This method uses [`InstructBlipVideoImageProcessor.__call__`] method to prepare image(s) or video(s) for the model, and [`BertTokenizerFast.__call__`] to prepare text for the model. - Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`, *optional*): + NOTE: Use `videos` instead. We only left this here for backwards compatibility. + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`VideoInput`, *optional*): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + -- **qformer_input_ids** - List of token ids from the Q-Former tokenizer to be fed to a model. Returned when `text` is not `None`. + -- **qformer_attention_mask** - List of indices specifying which tokens from the Q-Former tokenizer should be attended to by the model. Returned when `text` is not `None`. + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + if images is not None: + warnings.warn( + "The `images` argument is deprecated and will be removed in future versions, use `videos` instead.", + FutureWarning, + ) + if images is not None and videos is not None: + raise ValueError( + "You cannot provide both `images` and `videos` at the same time. Please pass video data as `videos=...` instead." + ) + if images is not None and videos is None: + videos = images + + output_kwargs = self._merge_kwargs( + InstructBlipVideoProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + encoding = BatchFeature() if text is not None: @@ -105,26 +152,15 @@ def __call__( _text_encoding = self.tokenizer( text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=None, # required to concatenate below - **kwargs, + **{ + **output_kwargs["text_kwargs"], + "return_tensors": None, # required to concatenate below + }, ) # if we know how many query tokens, expand text inside processor. We need this hacky manipulation # because BLIP expects image tokens to be at the beginning even before BOS token - if self.num_query_tokens is not None and images is not None: + if self.num_query_tokens is not None and videos is not None: text_encoding = {} video_tokens = ( self.video_token.content * self.num_query_tokens * 4 @@ -137,7 +173,7 @@ def __call__( ] else: text_encoding = _text_encoding - if images is not None: + if videos is not None: logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. 
" @@ -145,31 +181,16 @@ def __call__( ) # cast to desired return tensors type after concatenating - text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors) - encoding.update(text_encoding) - qformer_text_encoding = self.qformer_tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, + text_encoding = BatchFeature( + text_encoding, tensor_type=output_kwargs["common_kwargs"].get("return_tensors") ) + encoding.update(text_encoding) + qformer_text_encoding = self.qformer_tokenizer(text=text, **output_kwargs["text_kwargs"]) encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids") encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask") - if images is not None: - image_encoding = self.image_processor(images, return_tensors=return_tensors) + if videos is not None: + image_encoding = self.image_processor(videos, **output_kwargs["images_kwargs"]) encoding.update(image_encoding) return encoding diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index f744b9fcf9c1cd..c5a0eaa63739c2 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -53,7 +53,7 @@ from PIL import Image -def make_batched_images(images) -> List[List[ImageInput]]: +def make_batched_images(images) -> List[ImageInput]: """ Accepts images in list or nested list format, and makes a list of images for preprocessing. 
@@ -720,7 +720,11 @@ def preprocess( image_patches = self.get_image_patches( image, image_grid_pinpoints, - size=(size["shortest_edge"], size["shortest_edge"]), + size=( + (size["shortest_edge"], size["shortest_edge"]) + if "shortest_edge" in size + else (size["height"], size["width"]) + ), patch_size=crop_size["height"], resample=resample, data_format=input_data_format, diff --git a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py index e16e71875bb2c8..1000d8de635699 100644 --- a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py @@ -46,21 +46,28 @@ if is_vision_available(): - from PIL import Image + import PIL +# Copied from transformers.models.vivit.image_processing_vivit.make_batched_videos def make_batched_videos(videos) -> List[VideoInput]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos + if isinstance(videos[0][0], PIL.Image.Image) or len(videos[0][0].shape) == 3: + return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], Image.Image): + if isinstance(videos[0], PIL.Image.Image) or len(videos[0].shape) == 3: return [videos] elif len(videos[0].shape) == 4: - return [list(video) for video in videos] + return videos - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] + elif is_valid_image(videos): + if isinstance(videos, PIL.Image.Image) or len(videos.shape) == 3: + return [[videos]] + elif len(videos.shape) == 4: + return [videos] + elif len(videos.shape) == 5: + return videos raise ValueError(f"Could not make batched video from {videos}") @@ -212,7 +219,7 @@ def _preprocess( do_convert_rgb: bool = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - ) -> Image.Image: + ) -> PIL.Image.Image: """ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index efbb193ba62a9f..ad85ff8d15e266 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -16,13 +16,20 @@ Processor class for LLaVa-NeXT-Video. 
""" +import sys from typing import TYPE_CHECKING, List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType, logging +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack if TYPE_CHECKING: @@ -31,6 +38,17 @@ logger = logging.get_logger(__name__) +class LlavaNextVideoProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + class LlavaNextVideoProcessor(ProcessorMixin): r""" Constructs a LLaVa-NeXT-Video processor which wraps a LLaVa-NeXT image processor, LLaVa-NeXT-Video video processor and @@ -88,12 +106,10 @@ def __init__( def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], - images: ImageInput = None, - videos: VideoInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: int = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + images: Optional[ImageInput] = None, + videos: Optional[VideoInput] = None, + audio=None, + **kwargs: Unpack[LlavaNextVideoProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -105,36 +121,16 @@ def __call__( of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + images (`ImageInput`, *optional*): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + videos (`VideoInput`, *optional*): The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. 
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -143,15 +139,22 @@ def __call__( - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). - - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_images** -- Pixel values of images to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. """ + output_kwargs = self._merge_kwargs( + LlavaNextVideoProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: - image_inputs = self.image_processor(images, return_tensors=return_tensors) + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) else: image_inputs = {} if videos is not None: - videos_inputs = self.video_processor(videos, return_tensors=return_tensors) + videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"]) else: videos_inputs = {} @@ -160,8 +163,6 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - print(self.patch_size, self.vision_feature_select_strategy, image_inputs, videos_inputs.keys()) - if self.patch_size is None or self.vision_feature_select_strategy is None: prompt_strings = text logger.warning_once( @@ -203,16 +204,12 @@ def __call__( sample = sample.replace(self.video_token, self.video_token * num_video_tokens) prompt_strings.append(sample) - text_inputs = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) - print(text_inputs.keys()) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) - return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + return BatchFeature( + data={**text_inputs, **image_inputs, **videos_inputs}, + tensor_type=output_kwargs["common_kwargs"].get("return_tensors"), + ) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/mgp_str/processing_mgp_str.py b/src/transformers/models/mgp_str/processing_mgp_str.py index 207d4230ba09b7..4a3bdba95ad829 100644 --- a/src/transformers/models/mgp_str/processing_mgp_str.py +++ b/src/transformers/models/mgp_str/processing_mgp_str.py @@ -14,13 +14,24 @@ # limitations under the License. 
"""Processor class for MGP-STR.""" +import sys import warnings +from typing import List, Optional, Union from transformers import AutoTokenizer -from transformers.utils import is_torch_available -from transformers.utils.generic import ExplicitEnum -from ...processing_utils import ProcessorMixin +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils.generic import ExplicitEnum +from ...utils.import_utils import is_torch_available + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack if is_torch_available(): @@ -36,6 +47,10 @@ class DecodeType(ExplicitEnum): SUPPORTED_ANNOTATION_FORMATS = (DecodeType.CHARACTER, DecodeType.BPE, DecodeType.WORDPIECE) +class MgpstrProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + class MgpstrProcessor(ProcessorMixin): r""" Constructs a MGP-STR processor which wraps an image processor and MGP-STR tokenizers into a single @@ -50,9 +65,9 @@ class MgpstrProcessor(ProcessorMixin): The tokenizer is a required input. """ - attributes = ["image_processor", "char_tokenizer"] + attributes = ["image_processor", "tokenizer"] image_processor_class = "ViTImageProcessor" - char_tokenizer_class = "MgpstrTokenizer" + tokenizer_class = "MgpstrTokenizer" def __init__(self, image_processor=None, tokenizer=None, **kwargs): feature_extractor = None @@ -70,34 +85,87 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): if tokenizer is None: raise ValueError("You need to specify a `tokenizer`.") - self.char_tokenizer = tokenizer + self.tokenizer = tokenizer self.bpe_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") self.wp_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") super().__init__(image_processor, tokenizer) - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + @property + def char_tokenizer(self): + warnings.warn( + "The `char_tokenizer` attribute is deprecated and will be removed in future versions, use `tokenizer` instead.", + FutureWarning, + ) + return self.tokenizer + + @char_tokenizer.setter + def char_tokenizer(self, value): + warnings.warn( + "The `char_tokenizer` attribute is deprecated and will be removed in future versions, use `tokenizer` instead.", + FutureWarning, + ) + self.tokenizer = value + + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + audio=None, + videos=None, + **kwargs: Unpack[MgpstrProcessorKwargs], + ) -> BatchFeature: """ When used in normal mode, this method forwards all its arguments to ViTImageProcessor's [`~ViTImageProcessor.__call__`] and returns its output. This method also forwards the `text` and `kwargs` arguments to MgpstrTokenizer's [`~MgpstrTokenizer.__call__`] if `text` is not `None` to encode the text. Please refer to the doctsring of the above methods for more information. + + Args: + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
+ images (`ImageInput`, *optional*): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **labels** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + if images is None and text is None: raise ValueError("You need to specify either an `images` or `text` input to process.") - if images is not None: - inputs = self.image_processor(images, return_tensors=return_tensors, **kwargs) - if text is not None: - encodings = self.char_tokenizer(text, return_tensors=return_tensors, **kwargs) + output_kwargs = self._merge_kwargs( + MgpstrProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) - if text is None: - return inputs - elif images is None: - return encodings - else: - inputs["labels"] = encodings["input_ids"] - return inputs + data = {} + if text is not None: + text_features = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + data.update(text_features) + if images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + if "input_ids" in data: + # For backwards compatibility. MGP-STR doesn't actually use the labels, but the tests do. + # And users also expect the labels--and only the labels--to be returned. + # This requirement, however, may be relaxed in future versions. + data = { + "pixel_values": image_features["pixel_values"], + "labels": data["input_ids"], + } + else: + data.update(image_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, sequences): """ @@ -201,7 +269,7 @@ def char_decode(self, sequences): Returns: `List[str]`: The list of char decoded sentences. 
""" - decode_strs = [seq.replace(" ", "") for seq in self.char_tokenizer.batch_decode(sequences)] + decode_strs = [seq.replace(" ", "") for seq in self.tokenizer.batch_decode(sequences)] return decode_strs def bpe_decode(self, sequences): diff --git a/src/transformers/models/musicgen_melody/processing_musicgen_melody.py b/src/transformers/models/musicgen_melody/processing_musicgen_melody.py index 34b1d1ec4d6d89..4379add2758f0a 100644 --- a/src/transformers/models/musicgen_melody/processing_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/processing_musicgen_melody.py @@ -16,14 +16,27 @@ Text/audio processor class for MusicGen Melody """ -from typing import List, Optional +import sys +from typing import List, Optional, Union import numpy as np -from ...processing_utils import ProcessorMixin +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput from ...utils import to_numpy +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class MusicgenMelodyProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + class MusicgenMelodyProcessor(ProcessorMixin): r""" Constructs a MusicGen Melody processor which wraps a Wav2Vec2 feature extractor - for raw audio waveform processing - and a T5 tokenizer into a single processor @@ -42,14 +55,18 @@ class MusicgenMelodyProcessor(ProcessorMixin): feature_extractor_class = "MusicgenMelodyFeatureExtractor" tokenizer_class = ("T5Tokenizer", "T5TokenizerFast") - def __init__(self, feature_extractor, tokenizer): - super().__init__(feature_extractor, tokenizer) - # Copied from transformers.models.musicgen.processing_musicgen.MusicgenProcessor.get_decoder_prompt_ids def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps) - def __call__(self, audio=None, text=None, **kwargs): + def __call__( + self, + audio: Optional[AudioInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images=None, + videos=None, + **kwargs: Unpack[MusicgenMelodyProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `audio` and `kwargs` arguments to MusicgenMelodyFeatureExtractor's [`~MusicgenMelodyFeatureExtractor.__call__`] if `audio` is not @@ -57,41 +74,42 @@ def __call__(self, audio=None, text=None, **kwargs): PreTrainedTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not `None`. Please refer to the doctsring of the above two methods for more information. Args: - audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + audio (`AudioInput`, *optional*): The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each audio should be a mono-stereo signal of shape (T), where T is the sample length of the audio. - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). 
If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - kwargs (*optional*): - Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the - tokenizer. + Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **input_features** -- Audio input features to be fed to a model. Returned when `audio` is not `None`. - **attention_mask** -- List of token indices specifying which tokens should be attended to by the model when `text` is not `None`. When only `audio` is specified, returns the timestamps attention mask. """ - sampling_rate = kwargs.pop("sampling_rate", None) - if audio is None and text is None: raise ValueError("You need to specify either an `audio` or `text` input to process.") + output_kwargs = self._merge_kwargs( + MusicgenMelodyProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + data = {} if text is not None: - inputs = self.tokenizer(text, **kwargs) + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + data.update(text_features) if audio is not None: - audio_inputs = self.feature_extractor(audio, sampling_rate=sampling_rate, **kwargs) - - if text is None: - return audio_inputs - elif audio is None: - return inputs - else: - inputs["input_features"] = audio_inputs["input_features"] - return inputs + audio_features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + if text is not None: + data["input_features"] = audio_features["input_features"] + else: + data.update(audio_features) + return BatchFeature(data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) # Copied from transformers.models.musicgen.processing_musicgen.MusicgenProcessor.batch_decode with padding_mask->attention_mask def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py index eabf5b7069f200..406145b82fdc0f 100644 --- a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py @@ -16,13 +16,29 @@ Processor class for Qwen2Audio. """ +import sys from typing import List, Optional, Union -import numpy as np - from ...feature_extraction_utils import BatchFeature -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class Qwen2AudioProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "audio_kwargs": { + "padding": False, + }, + } class Qwen2AudioProcessor(ProcessorMixin): @@ -38,8 +54,8 @@ class Qwen2AudioProcessor(ProcessorMixin): tokenizer ([`Qwen2TokenizerFast`], *optional*): The tokenizer is a required input. chat_template (`Optional[str]`, *optional*): - The Jinja template to use for formatting the conversation. If not provided, the default chat template - is used. + The Jinja template to use for formatting the conversation. 
If not provided, the default chat template + is used. """ attributes = ["feature_extractor", "tokenizer"] @@ -53,11 +69,11 @@ def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None): def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - audios: Union[np.ndarray, List[np.ndarray]] = None, - padding: Union[bool, str, PaddingStrategy] = False, - sampling_rate: Optional[int] = None, - **kwargs, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audios: Optional[AudioInput] = None, + images=None, + videos=None, + **kwargs: Unpack[Qwen2AudioProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text` @@ -67,39 +83,44 @@ def __call__( of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - audios (`np.ndarray`, `List[np.ndarray]`): + audios (`AudioInput`, *optional*): The audio or batch of audios to be prepared. Each audio can be a NumPy array. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - sampling_rate (`int`, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **input_features** -- Audio input features to be fed to a model. Returned when `audios` is not `None`. 
""" if text is None: raise ValueError("You need to specify either a `text` input to process.") - inputs = self.tokenizer(text, padding=padding, **kwargs) + + output_kwargs = self._merge_kwargs( + Qwen2AudioProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + # Temporary fix for "paddding_side" in init_kwargs + _ = output_kwargs["text_kwargs"].pop("padding_side", None) + + data = self.tokenizer(text, **output_kwargs["text_kwargs"]) if audios is not None: - audio_inputs = self.feature_extractor( - audios, sampling_rate=sampling_rate, return_attention_mask=True, padding="max_length", **kwargs - ) + audio_inputs = self.feature_extractor(audios, **output_kwargs["audio_kwargs"]) audio_inputs["feature_attention_mask"] = audio_inputs.pop( "attention_mask" ) # rename attention_mask to prevent conflicts later on - inputs.update(audio_inputs) + data.update(audio_inputs) - return BatchFeature(data={**inputs}) + return BatchFeature(data={**data}, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/seamless_m4t/processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/processing_seamless_m4t.py index 7e838913ca147c..866bdf33aa82e9 100644 --- a/src/transformers/models/seamless_m4t/processing_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/processing_seamless_m4t.py @@ -16,7 +16,44 @@ Audio/Text processor class for SeamlessM4T """ -from ...processing_utils import ProcessorMixin +import sys +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class SeamlessM4TProcessorAudioKwargs(TextKwargs, total=False): + do_normalize_per_mel_bins: Optional[bool] + + +class SeamlessM4TProcessorTextKwargs(TextKwargs, total=False): + src_lang: Optional[str] + tgt_lang: Optional[str] + + +class SeamlessM4TProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: SeamlessM4TProcessorTextKwargs + audio_kwargs: SeamlessM4TProcessorAudioKwargs + _defaults = { + "text_kwargs": { + "padding": True, + "pad_to_multiple_of": 2, + }, + "audio_kwargs": { + "do_normalize_per_mel_bins": True, + "padding": True, + "pad_to_multiple_of": 2, + "truncation": True, + }, + } class SeamlessM4TProcessor(ProcessorMixin): @@ -41,7 +78,14 @@ class SeamlessM4TProcessor(ProcessorMixin): def __init__(self, feature_extractor, tokenizer): super().__init__(feature_extractor, tokenizer) - def __call__(self, text=None, audios=None, src_lang=None, tgt_lang=None, **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audios: Optional[AudioInput] = None, + images=None, + videos=None, + **kwargs: Unpack[SeamlessM4TProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text` and `kwargs` arguments to SeamlessM4TTokenizerFast's [`~SeamlessM4TTokenizerFast.__call__`] if `text` is not @@ -50,24 +94,17 @@ def __call__(self, text=None, audios=None, src_lang=None, tgt_lang=None, **kwarg to the doctsring of the above two methods for more information. 
Args: - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + audios (`AudioInput`, *optional*): The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the sample length of the audio. - src_lang (`str`, *optional*): - The language code of the input texts/audios. If not specified, the last `src_lang` specified will be - used. - tgt_lang (`str`, *optional*): - The code of the target language. If not specified, the last `tgt_lang` specified will be used. - kwargs (*optional*): - Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the - tokenizer. + Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when @@ -75,7 +112,6 @@ def __call__(self, text=None, audios=None, src_lang=None, tgt_lang=None, **kwarg `None`). - **input_features** -- Audio input features to be fed to a model. Returned when `audios` is not `None`. """ - sampling_rate = kwargs.pop("sampling_rate", None) if text is None and audios is None: raise ValueError("You have to specify either text or audios. Both cannot be none.") @@ -83,18 +119,21 @@ def __call__(self, text=None, audios=None, src_lang=None, tgt_lang=None, **kwarg raise ValueError( "Text and audios are mututally exclusive when passed to `SeamlessM4T`. Specify one or another." ) - elif text is not None: - if tgt_lang is not None: - self.tokenizer.tgt_lang = tgt_lang - if src_lang is not None: - self.tokenizer.src_lang = src_lang - encoding = self.tokenizer(text, **kwargs) - - return encoding - - else: - encoding = self.feature_extractor(audios, sampling_rate=sampling_rate, **kwargs) - return encoding + + output_kwargs = self._merge_kwargs( + SeamlessM4TProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + data = {} + if text is not None: + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + data.update(text_features) + if audios is not None: + audio_features = self.feature_extractor(audios, **output_kwargs["audio_kwargs"]) + data.update(audio_features) + return BatchFeature(data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/siglip/processing_siglip.py b/src/transformers/models/siglip/processing_siglip.py index 655fb4d4f78ab0..f8f3e8f9eaff49 100644 --- a/src/transformers/models/siglip/processing_siglip.py +++ b/src/transformers/models/siglip/processing_siglip.py @@ -16,13 +16,30 @@ Image/Text processor class for SigLIP. 
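The SeamlessM4T hunks just above also show how a processor can subclass the typed kwargs to carry model-specific options (`src_lang`/`tgt_lang` on the text side, `do_normalize_per_mel_bins` on the audio side) instead of taking them as dedicated positional arguments. A minimal usage sketch, assuming the `facebook/hf-seamless-m4t-medium` checkpoint (illustrative, not part of this diff):

from transformers import SeamlessM4TProcessor

processor = SeamlessM4TProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")

# `src_lang` is declared on SeamlessM4TProcessorTextKwargs, so `_merge_kwargs` routes it to the
# tokenizer call together with the `padding=True` / `pad_to_multiple_of=2` defaults shown above.
text_inputs = processor(text="Hello, my dog is cute.", src_lang="eng", return_tensors="pt")
print(text_inputs["input_ids"].shape)

Callers that previously passed `src_lang=` or `tgt_lang=` as explicit arguments keep working, since the same names now resolve through the kwargs class.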
""" +import sys from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class SiglipProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } class SiglipProcessor(ProcessorMixin): @@ -48,12 +65,11 @@ def __init__(self, image_processor, tokenizer): def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - images: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: int = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + audio=None, + videos=None, + **kwargs: Unpack[SiglipProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -63,33 +79,13 @@ def __call__( of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + images (`ImageInput`, *optional*): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. 
- - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -104,21 +100,20 @@ def __call__( if text is None and images is None: raise ValueError("You have to specify either text or images. Both cannot be none.") - if text is not None: - encoding = self.tokenizer( - text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length - ) + output_kwargs = self._merge_kwargs( + SiglipProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} + if text is not None: + text_features = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + data.update(text_features) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors) - - if text is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif text is not None: - return encoding - else: - return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data.update(image_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/speecht5/processing_speecht5.py b/src/transformers/models/speecht5/processing_speecht5.py index 468a0c1d89ab21..a9ce1bf4dea72b 100644 --- a/src/transformers/models/speecht5/processing_speecht5.py +++ b/src/transformers/models/speecht5/processing_speecht5.py @@ -14,7 +14,32 @@ # limitations under the License. """Speech processor class for SpeechT5.""" -from ...processing_utils import ProcessorMixin +import sys +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class SpeechT5ProcessorAudioKwargs(AudioKwargs, total=False): + audio_target: Optional[AudioInput] + + +class SpeechT5ProcessorTextKwargs(TextKwargs, total=False): + text_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] + + +class SpeechT5ProcessorKwargs(ProcessingKwargs, total=False): + audio_kwargs: SpeechT5ProcessorAudioKwargs + text_kwargs: SpeechT5ProcessorTextKwargs + _defaults = {} class SpeechT5Processor(ProcessorMixin): @@ -31,13 +56,21 @@ class SpeechT5Processor(ProcessorMixin): An instance of [`SpeechT5Tokenizer`]. The tokenizer is a required input. """ + attributes = ["feature_extractor", "tokenizer"] feature_extractor_class = "SpeechT5FeatureExtractor" tokenizer_class = "SpeechT5Tokenizer" def __init__(self, feature_extractor, tokenizer): super().__init__(feature_extractor, tokenizer) - def __call__(self, *args, **kwargs): + def __call__( + self, + audio: Optional[AudioInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images=None, + videos=None, + **kwargs: Unpack[SpeechT5ProcessorKwargs], + ) -> BatchFeature: """ Processes audio and text input, as well as audio and text targets. 
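As a usage sketch of the call pattern these processor changes converge on (the checkpoint name and dummy image are illustrative, not taken from this diff): per-call options are passed as plain keyword arguments, `_merge_kwargs` dispatches them to the tokenizer or image processor, and anything not supplied falls back to the `_defaults` declared on the `*ProcessorKwargs` class. For SigLIP that means no padding unless requested and PyTorch tensors by default:

from PIL import Image

from transformers import SiglipProcessor

processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")
image = Image.new("RGB", (224, 224))  # dummy image standing in for real data

batch = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=image,
    padding="max_length",  # overrides the `padding: False` default in text_kwargs
)
# A single BatchFeature covers both modalities; "pt" comes from the common_kwargs default.
print(type(batch).__name__, list(batch.keys()))

The same call shape applies to every processor touched in this diff; only the kwargs class and the modalities involved differ.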
@@ -60,12 +93,34 @@ def __call__(self, *args, **kwargs): - `audio` and `text_target` Please refer to the docstring of the above two methods for more information. + + Args: + audio (`AudioInput`, *optional*): + The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case + of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, + and T the sample length of the audio. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + - **input_features** -- Audio input features to be fed to a model. Returned when `audio` is not `None`. + - **attention_mask** -- List of indices specifying which timestamps should be attended to by the model when `audio` is not `None`. + When only `text` is specified, returns the token attention mask. + - **labels** -- List of token ids to be fed to a model. Returned when both `text` and `audio` are not `None`. + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `audio` is `None`. """ - audio = kwargs.pop("audio", None) - text = kwargs.pop("text", None) - text_target = kwargs.pop("text_target", None) - audio_target = kwargs.pop("audio_target", None) - sampling_rate = kwargs.pop("sampling_rate", None) + + output_kwargs = self._merge_kwargs( + SpeechT5ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + audio_target = output_kwargs["audio_kwargs"].pop("audio_target", None) + text_target = output_kwargs["text_kwargs"].pop("text_target", None) if audio is not None and text is not None: raise ValueError( @@ -80,33 +135,33 @@ def __call__(self, *args, **kwargs): "You need to specify either an `audio`, `audio_target`, `text`, or `text_target` input to process." 
) + input_data = {} if audio is not None: - inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + audio_features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + input_data.update(audio_features) elif text is not None: - inputs = self.tokenizer(text, **kwargs) - else: - inputs = None + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + input_data.update(text_features) + target_data = {} if audio_target is not None: - targets = self.feature_extractor(audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs) - labels = targets["input_values"] + target_audio_features = self.feature_extractor(audio_target=audio_target, **output_kwargs["audio_kwargs"]) + target_data.update(target_audio_features) elif text_target is not None: - targets = self.tokenizer(text_target, **kwargs) - labels = targets["input_ids"] - else: - targets = None + target_text_features = self.tokenizer(text_target, **output_kwargs["text_kwargs"]) + target_data.update(target_text_features) - if inputs is None: - return targets + if not input_data: + return BatchFeature(target_data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) - if targets is not None: - inputs["labels"] = labels - - decoder_attention_mask = targets.get("attention_mask") - if decoder_attention_mask is not None: - inputs["decoder_attention_mask"] = decoder_attention_mask + if target_data: + input_data["labels"] = ( + target_data["input_values"] if audio_target is not None else target_data["input_ids"] + ) + if (decoder_attention_mask := target_data.get("attention_mask")) is not None: + input_data["decoder_attention_mask"] = decoder_attention_mask - return inputs + return BatchFeature(input_data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def pad(self, *args, **kwargs): """ diff --git a/src/transformers/models/tvp/image_processing_tvp.py b/src/transformers/models/tvp/image_processing_tvp.py index 100ec133e8b026..60588d213477f3 100644 --- a/src/transformers/models/tvp/image_processing_tvp.py +++ b/src/transformers/models/tvp/image_processing_tvp.py @@ -32,6 +32,7 @@ ChannelDimension, ImageInput, PILImageResampling, + VideoInput, get_image_size, is_valid_image, to_numpy_array, @@ -48,16 +49,25 @@ logger = logging.get_logger(__name__) -# Copied from transformers.models.vivit.image_processing_vivit.make_batched -def make_batched(videos) -> List[List[ImageInput]]: +# Copied from transformers.models.vivit.image_processing_vivit.make_batched_videos +def make_batched_videos(videos) -> List[VideoInput]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos + if isinstance(videos[0][0], PIL.Image.Image) or len(videos[0][0].shape) == 3: + return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - return [videos] + if isinstance(videos[0], PIL.Image.Image) or len(videos[0].shape) == 3: + return [videos] + elif len(videos[0].shape) == 4: + return videos elif is_valid_image(videos): - return [[videos]] + if isinstance(videos, PIL.Image.Image) or len(videos.shape) == 3: + return [[videos]] + elif len(videos.shape) == 4: + return [videos] + elif len(videos.shape) == 5: + return videos raise ValueError(f"Could not make batched video from {videos}") @@ -443,7 +453,7 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." 
) - videos = make_batched(videos) + videos = make_batched_videos(videos) videos = [ np.array( diff --git a/src/transformers/models/tvp/processing_tvp.py b/src/transformers/models/tvp/processing_tvp.py index eb8aabfdade3ed..7ce29d9e9e1a53 100644 --- a/src/transformers/models/tvp/processing_tvp.py +++ b/src/transformers/models/tvp/processing_tvp.py @@ -16,8 +16,35 @@ Processor class for TVP. """ -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +import sys +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import VideoInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class TvpTextKwargs(TextKwargs, total=False): + pad_to_max_length: bool + + +class TvpProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: TvpTextKwargs + _defaults = { + "text_kwargs": { + "padding": "max_length", + "truncation": True, + "pad_to_max_length": True, + "return_token_type_ids": False, + }, + } class TvpProcessor(ProcessorMixin): @@ -46,7 +73,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): super().__init__(image_processor, tokenizer) - def __call__(self, text=None, videos=None, return_tensors=None, **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + videos: Optional[VideoInput] = None, + images=None, + audio=None, + **kwargs: Unpack[TvpProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode @@ -65,16 +99,8 @@ def __call__(self, text=None, videos=None, return_tensors=None, **kwargs): each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of channels. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when @@ -83,30 +109,26 @@ def __call__(self, text=None, videos=None, return_tensors=None, **kwargs): - **pixel_values** -- Pixel values to be fed to a model. Returned when `videos` is not `None`. """ - max_text_length = kwargs.pop("max_text_length", None) + if "max_text_length" in kwargs: + kwargs["max_length"] = kwargs.pop("max_text_length") if text is None and videos is None: raise ValueError("You have to specify either text or videos. 
Both cannot be none.") - encoding = {} - if text is not None: - textual_input = self.tokenizer.batch_encode_plus( - text, - truncation=True, - padding="max_length", - max_length=max_text_length, - pad_to_max_length=True, - return_tensors=return_tensors, - return_token_type_ids=False, - **kwargs, - ) - encoding.update(textual_input) + output_kwargs = self._merge_kwargs( + TvpProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} + if text is not None: + text_features = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + data.update(text_features) if videos is not None: - image_features = self.image_processor(videos, return_tensors=return_tensors, **kwargs) - encoding.update(image_features) - - return BatchEncoding(data=encoding, tensor_type=return_tensors) + video_features = self.image_processor(videos, **output_kwargs["videos_kwargs"]) + data.update(video_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/video_llava/image_processing_video_llava.py b/src/transformers/models/video_llava/image_processing_video_llava.py index 3e77110c7d45a8..2472b9bdd85417 100644 --- a/src/transformers/models/video_llava/image_processing_video_llava.py +++ b/src/transformers/models/video_llava/image_processing_video_llava.py @@ -50,18 +50,25 @@ import PIL +# Copied from transformers.models.vivit.image_processing_vivit.make_batched_videos def make_batched_videos(videos) -> List[VideoInput]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos + if isinstance(videos[0][0], PIL.Image.Image) or len(videos[0][0].shape) == 3: + return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], PIL.Image.Image): + if isinstance(videos[0], PIL.Image.Image) or len(videos[0].shape) == 3: return [videos] elif len(videos[0].shape) == 4: - return [list(video) for video in videos] + return videos - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] + elif is_valid_image(videos): + if isinstance(videos, PIL.Image.Image) or len(videos.shape) == 3: + return [[videos]] + elif len(videos.shape) == 4: + return [videos] + elif len(videos.shape) == 5: + return videos raise ValueError(f"Could not make batched video from {videos}") diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index a06913d7acf760..49c103fabe9729 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -16,18 +16,36 @@ Processor class for VideoLlava. 
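The rewritten `make_batched_videos` helpers above now tell single frames, single videos and batches of videos apart by array rank rather than by list nesting alone. A quick sketch of the accepted shapes, calling the VideoLlava copy of the helper directly (it is an internal utility, so this is illustration only, not public API):

import numpy as np

from transformers.models.video_llava.image_processing_video_llava import make_batched_videos

frame = np.zeros((224, 224, 3), dtype=np.uint8)         # rank 3: a single frame
video = np.zeros((8, 224, 224, 3), dtype=np.uint8)      # rank 4: one video of 8 frames
batch = np.zeros((2, 8, 224, 224, 3), dtype=np.uint8)   # rank 5: a batch of 2 videos

print(len(make_batched_videos(frame)))  # 1 video containing the lone frame
print(len(make_batched_videos(video)))  # 1 video
print(len(make_batched_videos(batch)))  # 2 videos

PIL frames and nested lists keep their previous behaviour; the rank-4 and rank-5 array handling is the part that changes across these files.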
""" +import sys from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature -from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType, logging +from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack logger = logging.get_logger(__name__) +class VideoLlavaProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + class VideoLlavaProcessor(ProcessorMixin): r""" Constructs a VideoLlava processor which wraps a VideoLlava image processor and a Llava tokenizer into a single processor. @@ -77,13 +95,11 @@ def __init__( def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - images: ImageInput = None, - videos: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length=None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + videos: Optional[VideoInput] = None, + audio=None, + **kwargs: Unpack[VideoLlavaProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -93,38 +109,18 @@ def __call__( of the above two methods for more information. Args: - text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`): + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + images (`ImageInput`, *optional*): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + videos (`VideoInput`, *optional*): Video frames to preprocess. Expects a single or batch of video frames in NumPy array or PyTorch tensor. Each video should be of shape (T, C, H, W), where T is number of frames, C is number of channels, H and W are image height and width. 
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -135,9 +131,17 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + output_kwargs = self._merge_kwargs( + VideoLlavaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + # Temporary fix for "paddding_side" in init_kwargs + _ = output_kwargs["text_kwargs"].pop("padding_side", None) + data = {} if images is not None or videos is not None: - encoded_images = self.image_processor(images=images, videos=videos, return_tensors=return_tensors) + encoded_images = self.image_processor(images=images, videos=videos, **output_kwargs["images_kwargs"]) data.update(encoded_images) if isinstance(text, str): @@ -174,16 +178,10 @@ def __call__( sample = sample.replace(self.video_token, self.video_token * num_video_tokens) prompt_strings.append(sample) - text_inputs = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) data.update(text_inputs) - return BatchFeature(data=data) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/videomae/image_processing_videomae.py b/src/transformers/models/videomae/image_processing_videomae.py index 413589523aa675..7355e356196ca4 100644 --- a/src/transformers/models/videomae/image_processing_videomae.py +++ b/src/transformers/models/videomae/image_processing_videomae.py @@ -30,6 +30,7 @@ ChannelDimension, ImageInput, PILImageResampling, + VideoInput, infer_channel_dimension_format, is_scaled_image, is_valid_image, @@ -47,15 +48,25 @@ logger = logging.get_logger(__name__) -def make_batched(videos) -> List[List[ImageInput]]: +# Copied from transformers.models.vivit.image_processing_vivit.make_batched_videos +def make_batched_videos(videos) -> List[VideoInput]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos + if isinstance(videos[0][0], 
PIL.Image.Image) or len(videos[0][0].shape) == 3: + return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - return [videos] + if isinstance(videos[0], PIL.Image.Image) or len(videos[0].shape) == 3: + return [videos] + elif len(videos[0].shape) == 4: + return videos elif is_valid_image(videos): - return [[videos]] + if isinstance(videos, PIL.Image.Image) or len(videos.shape) == 3: + return [[videos]] + elif len(videos.shape) == 4: + return [videos] + elif len(videos.shape) == 5: + return videos raise ValueError(f"Could not make batched video from {videos}") @@ -317,7 +328,7 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." ) - videos = make_batched(videos) + videos = make_batched_videos(videos) videos = [ [ diff --git a/src/transformers/models/vilt/image_processing_vilt.py b/src/transformers/models/vilt/image_processing_vilt.py index 66ffeb816fec5e..f2c3529218e257 100644 --- a/src/transformers/models/vilt/image_processing_vilt.py +++ b/src/transformers/models/vilt/image_processing_vilt.py @@ -112,8 +112,8 @@ def get_resize_output_image_size( new_width = scale * new_width new_height, new_width = int(new_height + 0.5), int(new_width + 0.5) - new_height = new_height // size_divisor * size_divisor - new_width = new_width // size_divisor * size_divisor + new_height = max(1, new_height // size_divisor) * size_divisor + new_width = max(1, new_width // size_divisor) * size_divisor return new_height, new_width @@ -236,9 +236,7 @@ def resize( The channel dimension format of the input image. If not provided, it will be inferred. """ size = get_size_dict(size, default_to_square=False) - if "shortest_edge" not in size: - raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}") - shorter = size["shortest_edge"] + shorter = size["shortest_edge"] if "shortest_edge" in size else min(size["height"], size["width"]) longer = int(1333 / 800 * shorter) output_size = get_resize_output_image_size( image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format diff --git a/src/transformers/models/vilt/processing_vilt.py b/src/transformers/models/vilt/processing_vilt.py index 0ccb884ea00c9d..562e5a3f94a955 100644 --- a/src/transformers/models/vilt/processing_vilt.py +++ b/src/transformers/models/vilt/processing_vilt.py @@ -16,12 +16,35 @@ Processor class for ViLT. 
""" +import sys import warnings from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class ViltProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + } class ViltProcessor(ProcessorMixin): @@ -63,53 +86,50 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): def __call__( self, - images, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, - ) -> BatchEncoding: + images: ImageInput, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[ViltProcessorKwargs], + ) -> BatchFeature: """ This method uses [`ViltImageProcessor.__call__`] method to prepare image(s) for the model, and [`BertTokenizerFast.__call__`] to prepare text for the model. - Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_mask** -- Mask for the pixel values. Returned when `images` is not `None`. 
""" - encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, + output_kwargs = self._merge_kwargs( + ViltProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) - # add pixel_values + pixel_mask - encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) - encoding.update(encoding_image_processor) - return encoding + data = {} + if text is not None: + text_features = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + data.update(text_features) + if images is not None: + images_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data.update(images_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/vivit/image_processing_vivit.py b/src/transformers/models/vivit/image_processing_vivit.py index 5f251bbd1b95b9..b50b09089f5114 100644 --- a/src/transformers/models/vivit/image_processing_vivit.py +++ b/src/transformers/models/vivit/image_processing_vivit.py @@ -34,6 +34,7 @@ ChannelDimension, ImageInput, PILImageResampling, + VideoInput, infer_channel_dimension_format, is_scaled_image, is_valid_image, @@ -50,15 +51,24 @@ logger = logging.get_logger(__name__) -def make_batched(videos) -> List[List[ImageInput]]: +def make_batched_videos(videos) -> List[VideoInput]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos + if isinstance(videos[0][0], PIL.Image.Image) or len(videos[0][0].shape) == 3: + return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - return [videos] + if isinstance(videos[0], PIL.Image.Image) or len(videos[0].shape) == 3: + return [videos] + elif len(videos[0].shape) == 4: + return videos elif is_valid_image(videos): - return [[videos]] + if isinstance(videos, PIL.Image.Image) or len(videos.shape) == 3: + return [[videos]] + elif len(videos.shape) == 4: + return [videos] + elif len(videos.shape) == 5: + return videos raise ValueError(f"Could not make batched video from {videos}") @@ -375,7 +385,7 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." 
) - videos = make_batched(videos) + videos = make_batched_videos(videos) videos = [ [ diff --git a/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py index d24c672007d734..b6c45167fe99e8 100644 --- a/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py @@ -16,13 +16,27 @@ Speech processor class for Wav2Vec2-BERT """ +import sys import warnings +from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput from ..seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class Wav2Vec2BertProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + class Wav2Vec2BertProcessor(ProcessorMixin): r""" Constructs a Wav2Vec2-BERT processor which wraps a Wav2Vec2-BERT feature extractor and a Wav2Vec2 CTC tokenizer into a single @@ -38,6 +52,7 @@ class Wav2Vec2BertProcessor(ProcessorMixin): An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input. """ + attributes = ["feature_extractor", "tokenizer"] feature_extractor_class = "SeamlessM4TFeatureExtractor" tokenizer_class = "AutoTokenizer" @@ -63,7 +78,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return cls(feature_extractor=feature_extractor, tokenizer=tokenizer) - def __call__(self, audio=None, text=None, **kwargs): + def __call__( + self, + audio: Optional[AudioInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images=None, + videos=None, + **kwargs: Unpack[Wav2Vec2BertProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `audio` and `kwargs` arguments to SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audio` is not @@ -71,17 +93,15 @@ def __call__(self, audio=None, text=None, **kwargs): PreTrainedTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not `None`. Please refer to the doctsring of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + audio (`AudioInput`, *optional*): The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the sample length of the audio. - kwargs (*optional*): - Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the - tokenizer. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. 
Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: - **input_features** -- Audio input features to be fed to a model. Returned when `audio` is not `None`. @@ -91,23 +111,26 @@ def __call__(self, audio=None, text=None, **kwargs): - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `audio` is `None`. """ - sampling_rate = kwargs.pop("sampling_rate", None) - if audio is None and text is None: raise ValueError("You need to specify either an `audio` or `text` input to process.") + output_kwargs = self._merge_kwargs( + Wav2Vec2BertProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + data = {} if audio is not None: - inputs = self.feature_extractor(audio, sampling_rate=sampling_rate, **kwargs) + audio_features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + data.update(audio_features) if text is not None: - encodings = self.tokenizer(text, **kwargs) - - if text is None: - return inputs - elif audio is None: - return encodings - else: - inputs["labels"] = encodings["input_ids"] - return inputs + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + if audio is not None: + data["labels"] = text_features["input_ids"] + else: + data.update(text_features) + return BatchFeature(data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def pad(self, input_features=None, labels=None, **kwargs): """ diff --git a/src/transformers/models/x_clip/processing_x_clip.py b/src/transformers/models/x_clip/processing_x_clip.py index a11aeb18dc4f59..f722ef37d498a8 100644 --- a/src/transformers/models/x_clip/processing_x_clip.py +++ b/src/transformers/models/x_clip/processing_x_clip.py @@ -16,10 +16,24 @@ Image/Text processor class for XCLIP """ +import sys import warnings +from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...feature_extraction_utils import BatchFeature +from ...image_utils import VideoInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class XCLIPProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} class XCLIPProcessor(ProcessorMixin): @@ -59,7 +73,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor - def __call__(self, text=None, videos=None, return_tensors=None, **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + videos: Optional[VideoInput] = None, + images=None, + audio=None, + **kwargs: Unpack[XCLIPProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). 
This method forwards the `text` and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode @@ -78,16 +99,8 @@ def __call__(self, text=None, videos=None, return_tensors=None, **kwargs): each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of channels. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when @@ -99,19 +112,20 @@ def __call__(self, text=None, videos=None, return_tensors=None, **kwargs): if text is None and videos is None: raise ValueError("You have to specify either text or videos. Both cannot be none.") - if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + output_kwargs = self._merge_kwargs( + XCLIPProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} + if text is not None: + text_features = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + data.update(text_features) if videos is not None: - image_features = self.image_processor(videos, return_tensors=return_tensors, **kwargs) - - if text is not None and videos is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif text is not None: - return encoding - else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + video_features = self.image_processor(videos, **output_kwargs["images_kwargs"]) + data.update(video_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ diff --git a/tests/models/altclip/test_processor_altclip.py b/tests/models/altclip/test_processor_altclip.py new file mode 100644 index 00000000000000..86a84ae9ab8bc6 --- /dev/null +++ b/tests/models/altclip/test_processor_altclip.py @@ -0,0 +1,16 @@ +import tempfile +import unittest + +from transformers import AltCLIPProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +class AltCLIPProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "BAAI/AltCLIP" + processor_class = AltCLIPProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py new file mode 100644 index 00000000000000..1efeaa5339d304 --- /dev/null +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -0,0 +1,33 @@ +import tempfile +import unittest + +from transformers import ChameleonProcessor +from transformers.models.auto.processing_auto import processor_class_from_name + +from ...test_processing_common import ProcessorTesterMixin + + +class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = 
"leloy/Anole-7b-v0.1-hf" + processor_class = ChameleonProcessor + + def get_component(self, attribute, **kwargs): + assert attribute in self.processor_class.attributes + component_class_name = getattr(self.processor_class, f"{attribute}_class") + if isinstance(component_class_name, tuple): + if "_fast" in component_class_name[0]: + component_class_name = component_class_name[0] + else: + component_class_name = component_class_name[1] + + component_class = processor_class_from_name(component_class_name) + component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa + if attribute == "tokenizer" and not component.pad_token: + component.pad_token = "[TEST_PAD]" + + return component + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) diff --git a/tests/models/clap/test_processor_clap.py b/tests/models/clap/test_processor_clap.py index 49e9972ea02e22..06153613d5a4fa 100644 --- a/tests/models/clap/test_processor_clap.py +++ b/tests/models/clap/test_processor_clap.py @@ -19,21 +19,32 @@ from transformers import ClapFeatureExtractor, ClapProcessor, RobertaTokenizer, RobertaTokenizerFast from transformers.testing_utils import require_sentencepiece, require_torchaudio +from ...test_processing_common import ProcessorTesterMixin from .test_feature_extraction_clap import floats_list @require_torchaudio @require_sentencepiece -class ClapProcessorTest(unittest.TestCase): +class ClapProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "laion/clap-htsat-unfused" + processor_class = ClapProcessor + def setUp(self): - self.checkpoint = "laion/clap-htsat-unfused" self.tmpdirname = tempfile.mkdtemp() + processor = ClapProcessor.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) + + def get_component(self, attribute, **kwargs): + assert attribute in self.processor_class.attributes + if attribute == "tokenizer": + return self.get_tokenizer(**kwargs) + return super().get_component(attribute, **kwargs) def get_tokenizer(self, **kwargs): - return RobertaTokenizer.from_pretrained(self.checkpoint, **kwargs) + return RobertaTokenizer.from_pretrained(self.from_pretrained_id, **kwargs) def get_feature_extractor(self, **kwargs): - return ClapFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + return ClapFeatureExtractor.from_pretrained(self.from_pretrained_id, **kwargs) def tearDown(self): shutil.rmtree(self.tmpdirname) diff --git a/tests/models/clvp/test_processor_clvp.py b/tests/models/clvp/test_processor_clvp.py index f751ab92d03d95..becd0e3bc7e24e 100644 --- a/tests/models/clvp/test_processor_clvp.py +++ b/tests/models/clvp/test_processor_clvp.py @@ -21,14 +21,19 @@ from transformers import ClvpFeatureExtractor, ClvpProcessor, ClvpTokenizer from transformers.testing_utils import require_torch +from ...test_processing_common import ProcessorTesterMixin from .test_feature_extraction_clvp import floats_list @require_torch -class ClvpProcessorTest(unittest.TestCase): +class ClvpProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "susnato/clvp_dev" + processor_class = ClvpProcessor + def setUp(self): - self.checkpoint = "susnato/clvp_dev" self.tmpdirname = tempfile.mkdtemp() + processor = ClvpProcessor.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) def tearDown(self): super().tearDown() @@ -37,11 +42,11 @@ def tearDown(self): # Copied from 
transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.get_tokenizer with Whisper->Clvp def get_tokenizer(self, **kwargs): - return ClvpTokenizer.from_pretrained(self.checkpoint, **kwargs) + return ClvpTokenizer.from_pretrained(self.from_pretrained_id, **kwargs) # Copied from transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.get_feature_extractor with Whisper->Clvp def get_feature_extractor(self, **kwargs): - return ClvpFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + return ClvpFeatureExtractor.from_pretrained(self.from_pretrained_id, **kwargs) # Copied from transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.test_save_load_pretrained_default with Whisper->Clvp def test_save_load_pretrained_default(self): diff --git a/tests/models/flava/test_processor_flava.py b/tests/models/flava/test_processor_flava.py index a83e459153d532..56a52ee21c7b07 100644 --- a/tests/models/flava/test_processor_flava.py +++ b/tests/models/flava/test_processor_flava.py @@ -22,16 +22,18 @@ import numpy as np import pytest -from transformers import BertTokenizer, BertTokenizerFast +from transformers import BertTokenizer, BertTokenizerFast, FlavaProcessor from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES from transformers.testing_utils import require_vision from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available +from ...test_processing_common import ProcessorTesterMixin + if is_vision_available(): from PIL import Image - from transformers import FlavaImageProcessor, FlavaProcessor + from transformers import FlavaImageProcessor from transformers.models.flava.image_processing_flava import ( FLAVA_CODEBOOK_MEAN, FLAVA_CODEBOOK_STD, @@ -41,7 +43,9 @@ @require_vision -class FlavaProcessorTest(unittest.TestCase): +class FlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = FlavaProcessor + def setUp(self): self.tmpdirname = tempfile.mkdtemp() diff --git a/tests/models/git/test_processor_git.py b/tests/models/git/test_processor_git.py index 95e436d8e4f526..d66260bc57483a 100644 --- a/tests/models/git/test_processor_git.py +++ b/tests/models/git/test_processor_git.py @@ -21,6 +21,8 @@ from transformers.testing_utils import require_vision from transformers.utils import is_vision_available +from ...test_processing_common import ProcessorTesterMixin + if is_vision_available(): from PIL import Image @@ -29,7 +31,9 @@ @require_vision -class GitProcessorTest(unittest.TestCase): +class GitProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = GitProcessor + def setUp(self): self.tmpdirname = tempfile.mkdtemp() diff --git a/tests/models/instructblipvideo/test_processor_instructblipvideo.py b/tests/models/instructblipvideo/test_processor_instructblipvideo.py new file mode 100644 index 00000000000000..9442a429944226 --- /dev/null +++ b/tests/models/instructblipvideo/test_processor_instructblipvideo.py @@ -0,0 +1,22 @@ +import tempfile +import unittest + +from transformers import InstructBlipVideoProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "Salesforce/instructblip-vicuna-7b" + processor_class = InstructBlipVideoProcessor + videos_data_arg_name = "pixel_values" + + def prepare_components(self): + components = super().prepare_components() + components["qformer_tokenizer"] = components["tokenizer"] + return components + + def 
setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) diff --git a/tests/models/llava_next_video/test_processor_llava_next_video.py b/tests/models/llava_next_video/test_processor_llava_next_video.py new file mode 100644 index 00000000000000..9cd4615d572547 --- /dev/null +++ b/tests/models/llava_next_video/test_processor_llava_next_video.py @@ -0,0 +1,16 @@ +import tempfile +import unittest + +from transformers import LlavaNextVideoProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "llava-hf/LLaVA-NeXT-Video-7B-hf" + processor_class = LlavaNextVideoProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) diff --git a/tests/models/mgp_str/test_processor_mgp_str.py b/tests/models/mgp_str/test_processor_mgp_str.py index 6a028a28424d61..a5322aa5d31435 100644 --- a/tests/models/mgp_str/test_processor_mgp_str.py +++ b/tests/models/mgp_str/test_processor_mgp_str.py @@ -20,29 +20,30 @@ import tempfile import unittest -import numpy as np import pytest -from transformers import MgpstrTokenizer +from transformers import MgpstrProcessor, MgpstrTokenizer from transformers.models.mgp_str.tokenization_mgp_str import VOCAB_FILES_NAMES from transformers.testing_utils import require_torch, require_vision from transformers.utils import IMAGE_PROCESSOR_NAME, is_torch_available, is_vision_available +from ...test_processing_common import ProcessorTesterMixin + if is_torch_available(): import torch if is_vision_available(): - from PIL import Image - - from transformers import MgpstrProcessor, ViTImageProcessor + from transformers import ViTImageProcessor @require_torch @require_vision -class MgpstrProcessorTest(unittest.TestCase): +class MgpstrProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = MgpstrProcessor image_processing_class = ViTImageProcessor if is_vision_available() else None + text_data_arg_name = "labels" @property def image_processor_dict(self): @@ -79,15 +80,6 @@ def get_image_processor(self, **kwargs): def tearDown(self): shutil.rmtree(self.tmpdirname) - def prepare_image_inputs(self): - """This function prepares a list of PIL images.""" - - image_input = np.random.randint(255, size=(3, 30, 400), dtype=np.uint8) - - image_input = Image.fromarray(np.moveaxis(image_input, 0, -1)) - - return image_input - def test_save_load_pretrained_default(self): tokenizer = self.get_tokenizer() image_processor = self.get_image_processor() diff --git a/tests/models/musicgen_melody/test_processor_musicgen_melody.py b/tests/models/musicgen_melody/test_processor_musicgen_melody.py index e00f31c495990f..ec65e3f1ffeed1 100644 --- a/tests/models/musicgen_melody/test_processor_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_processor_musicgen_melody.py @@ -20,13 +20,15 @@ import numpy as np -from transformers import T5Tokenizer, T5TokenizerFast +from transformers import MusicgenMelodyProcessor, T5Tokenizer, T5TokenizerFast from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio from transformers.utils.import_utils import is_torchaudio_available +from ...test_processing_common import ProcessorTesterMixin + if is_torchaudio_available(): - from transformers import 
MusicgenMelodyFeatureExtractor, MusicgenMelodyProcessor + from transformers import MusicgenMelodyFeatureExtractor global_rng = random.Random() @@ -51,17 +53,20 @@ def floats_list(shape, scale=1.0, rng=None, name=None): @require_sentencepiece @require_torchaudio # Copied from tests.models.musicgen.test_processing_musicgen.MusicgenProcessorTest with Musicgen->MusicgenMelody, Encodec->MusicgenMelody, padding_mask->attention_mask, input_values->input_features -class MusicgenMelodyProcessorTest(unittest.TestCase): +class MusicgenMelodyProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "facebook/musicgen-melody" + processor_class = MusicgenMelodyProcessor + def setUp(self): - # Ignore copy - self.checkpoint = "facebook/musicgen-melody" self.tmpdirname = tempfile.mkdtemp() + processor = MusicgenMelodyProcessor.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): - return T5Tokenizer.from_pretrained(self.checkpoint, **kwargs) + return T5Tokenizer.from_pretrained(self.from_pretrained_id, **kwargs) def get_feature_extractor(self, **kwargs): - return MusicgenMelodyFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + return MusicgenMelodyFeatureExtractor.from_pretrained(self.from_pretrained_id, **kwargs) def tearDown(self): shutil.rmtree(self.tmpdirname) diff --git a/tests/models/qwen2_audio/test_processor_qwen2_audio.py b/tests/models/qwen2_audio/test_processor_qwen2_audio.py index d324a7d9105091..760765ac8d64e2 100644 --- a/tests/models/qwen2_audio/test_processor_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_processor_qwen2_audio.py @@ -17,22 +17,30 @@ from transformers import AutoProcessor, AutoTokenizer, Qwen2AudioProcessor, WhisperFeatureExtractor from transformers.testing_utils import require_torch, require_torchaudio +from ...test_processing_common import ProcessorTesterMixin + @require_torch @require_torchaudio -class Qwen2AudioProcessorTest(unittest.TestCase): +class Qwen2AudioProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "Qwen/Qwen2-Audio-7B-Instruct" + processor_class = Qwen2AudioProcessor + def setUp(self): - self.checkpoint = "Qwen/Qwen2-Audio-7B-Instruct" self.tmpdirname = tempfile.mkdtemp() + tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id) + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor = Qwen2AudioProcessor(tokenizer=tokenizer, feature_extractor=processor.feature_extractor) + processor.save_pretrained(self.tmpdirname) def test_can_load_various_tokenizers(self): - processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint) - tokenizer = AutoTokenizer.from_pretrained(self.checkpoint) + processor = Qwen2AudioProcessor.from_pretrained(self.from_pretrained_id) + tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id) self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__) def test_save_load_pretrained_default(self): - tokenizer = AutoTokenizer.from_pretrained(self.checkpoint) - processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint) + tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id) + processor = Qwen2AudioProcessor.from_pretrained(self.from_pretrained_id) feature_extractor = processor.feature_extractor processor = Qwen2AudioProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) @@ -45,8 +53,8 @@ def test_save_load_pretrained_default(self): self.assertIsInstance(processor.feature_extractor, WhisperFeatureExtractor) def 
test_tokenizer_integration(self): - slow_tokenizer = AutoTokenizer.from_pretrained(self.checkpoint, use_fast=False) - fast_tokenizer = AutoTokenizer.from_pretrained(self.checkpoint, from_slow=True, legacy=False) + slow_tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id, use_fast=False) + fast_tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id, from_slow=True, legacy=False) prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>\nWhat is it in this audio?<|im_end|><|im_start|>assistant\n" EXPECTED_OUTPUT = [ @@ -82,7 +90,7 @@ def test_tokenizer_integration(self): self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT) def test_chat_template(self): - processor = AutoProcessor.from_pretrained(self.checkpoint) + processor = AutoProcessor.from_pretrained(self.from_pretrained_id) expected_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass shattering.<|im_end|>\n<|im_start|>user\nAudio 2: <|audio_bos|><|AUDIO|><|audio_eos|>\nHow about this one?<|im_end|>\n<|im_start|>assistant\n" messages = [ diff --git a/tests/models/seamless_m4t/test_processor_seamless_m4t.py b/tests/models/seamless_m4t/test_processor_seamless_m4t.py index 7beefb16bda7ea..701f7002bb3d2c 100644 --- a/tests/models/seamless_m4t/test_processor_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_processor_seamless_m4t.py @@ -21,26 +21,49 @@ SeamlessM4TTokenizer, SeamlessM4TTokenizerFast, ) -from transformers.testing_utils import require_torch +from transformers.testing_utils import require_torch, require_vision +from ...test_processing_common import ProcessorTesterMixin from .test_feature_extraction_seamless_m4t import floats_list @require_torch -class SeamlessM4TProcessorTest(unittest.TestCase): +class SeamlessM4TProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "facebook/hf-seamless-m4t-medium" + processor_class = SeamlessM4TProcessor + def setUp(self): - self.checkpoint = "facebook/hf-seamless-m4t-medium" self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): - return SeamlessM4TTokenizer.from_pretrained(self.checkpoint, **kwargs) + return SeamlessM4TTokenizer.from_pretrained(self.from_pretrained_id, **kwargs) def get_feature_extractor(self, **kwargs): - return SeamlessM4TFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + return SeamlessM4TFeatureExtractor.from_pretrained(self.from_pretrained_id, **kwargs) def tearDown(self): shutil.rmtree(self.tmpdirname) + @require_vision + @require_torch + def test_tokenizer_defaults_preserved_by_kwargs(self): + if "tokenizer" not in self.processor_class.attributes: + self.skipTest(f"tokenizer attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["tokenizer"] = self.get_component( + "tokenizer", max_length=117, padding="max_length", pad_to_multiple_of=1 + ) + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() if "image_processor" in self.processor_class.attributes else None + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + 
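        # The expected 117 here matches the tokenizer constructed above with
        # max_length=117 and padding="max_length": the uniformized processors feed
        # tokenizer.init_kwargs into _merge_kwargs as defaults, so init-time settings
        # still apply to a plain processor(text=..., images=...) call and are only
        # changed when the caller passes overriding kwargs (e.g. max_length=112, as
        # exercised by test_kwargs_overrides_default_tokenizer_kwargs below).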
self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 117) + def test_save_load_pretrained_default(self): tokenizer = self.get_tokenizer() feature_extractor = self.get_feature_extractor() diff --git a/tests/models/siglip/test_processor_siglip.py b/tests/models/siglip/test_processor_siglip.py new file mode 100644 index 00000000000000..608ff70539a218 --- /dev/null +++ b/tests/models/siglip/test_processor_siglip.py @@ -0,0 +1,16 @@ +import tempfile +import unittest + +from transformers import SiglipProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +class SiglipProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "google/siglip-base-patch16-224" + processor_class = SiglipProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) diff --git a/tests/models/speecht5/test_processor_speecht5.py b/tests/models/speecht5/test_processor_speecht5.py index 97d3842f105ad6..cb4e52e52c0764 100644 --- a/tests/models/speecht5/test_processor_speecht5.py +++ b/tests/models/speecht5/test_processor_speecht5.py @@ -19,14 +19,16 @@ import tempfile import unittest -from transformers import is_speech_available, is_torch_available +from transformers import SpeechT5Processor, is_speech_available, is_torch_available from transformers.models.speecht5 import SpeechT5Tokenizer -from transformers.testing_utils import get_tests_dir, require_torch +from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio from transformers.utils import FEATURE_EXTRACTOR_NAME +from ...test_processing_common import ProcessorTesterMixin + if is_speech_available() and is_torch_available(): - from transformers import SpeechT5FeatureExtractor, SpeechT5Processor + from transformers import SpeechT5FeatureExtractor from .test_feature_extraction_speecht5 import floats_list @@ -35,7 +37,10 @@ @require_torch -class SpeechT5ProcessorTest(unittest.TestCase): +@require_torchaudio +class SpeechT5ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = SpeechT5Processor + def setUp(self): self.tmpdirname = tempfile.mkdtemp() diff --git a/tests/models/tvp/test_processor_tvp.py b/tests/models/tvp/test_processor_tvp.py new file mode 100644 index 00000000000000..40d700e0beea15 --- /dev/null +++ b/tests/models/tvp/test_processor_tvp.py @@ -0,0 +1,75 @@ +import inspect +import tempfile +import unittest + +from transformers import TvpProcessor +from transformers.testing_utils import require_torch, require_vision + +from ...test_processing_common import ProcessorTesterMixin + + +class TvpProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "Jiqing/tiny-random-tvp" + processor_class = TvpProcessor + videos_data_arg_name = "pixel_values" + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) + + @require_torch + @require_vision + def test_video_processor_defaults_preserved_by_kwargs(self): + if "video_processor" not in self.processor_class.attributes and ( + "videos" not in inspect.signature(self.processor_class.__call__).parameters + or inspect.signature(self.processor_class.__call__).parameters["videos"].annotation == inspect._empty + ): + self.skipTest(f"video_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + 
processor_components["image_processor"] = self.get_component( + "image_processor", size=(234, 234), crop_size=(234, 234), do_pad=False + ) + if "video_processor" in self.processor_class.attributes: + processor_components["video_processor"] = self.get_component( + "video_processor", size=(234, 234), crop_size=(234, 234), do_pad=False + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + video_input = self.prepare_video_inputs() + + inputs = processor(text=input_str, videos=video_input, return_tensors="pt") + self.assertEqual(inputs[self.videos_data_arg_name].shape[-1], 234) + + @require_torch + @require_vision + def test_image_processor_defaults_preserved_by_image_kwargs(self): + self.skipTest("TVP does not process images") + + @require_torch + @require_vision + def test_kwargs_overrides_default_image_processor_kwargs(self): + self.skipTest("TVP does not process images") + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + self.skipTest("TVP does not process images") + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + self.skipTest("TVP does not process images") + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + self.skipTest("TVP does not process images") + + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + self.skipTest("TVP does not process images") diff --git a/tests/models/video_llava/test_processor_video_llava.py b/tests/models/video_llava/test_processor_video_llava.py new file mode 100644 index 00000000000000..9ddc84a6bcb944 --- /dev/null +++ b/tests/models/video_llava/test_processor_video_llava.py @@ -0,0 +1,17 @@ +import tempfile +import unittest + +from transformers.models.video_llava.processing_video_llava import VideoLlavaProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +class VideoLlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "LanguageBind/Video-LLaVA-7B-hf" + processor_class = VideoLlavaProcessor + images_data_arg_name = "pixel_values_images" + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) diff --git a/tests/models/vilt/test_processor_vilt.py b/tests/models/vilt/test_processor_vilt.py new file mode 100644 index 00000000000000..0ae6a5256d1b32 --- /dev/null +++ b/tests/models/vilt/test_processor_vilt.py @@ -0,0 +1,178 @@ +import tempfile +import unittest + +from transformers import ViltProcessor +from transformers.testing_utils import require_torch, require_vision + +from ...test_processing_common import ProcessorTesterMixin + + +class ViltProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "dandelin/vilt-b32-mlm" + processor_class = ViltProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) + + @require_torch + @require_vision + def test_image_processor_defaults_preserved_by_image_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + 
processor_components["image_processor"] = self.get_component( + "image_processor", size=(234, 234), crop_size=(234, 234), size_divisor=32 + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + # VILT resizes images to dims divisible by size_divisor + vilt_compatible_image_size = (32, 384) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], vilt_compatible_image_size[-1]) + + @require_torch + @require_vision + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component("image_processor", size=(234, 234)) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + # VILT resizes images to dims divisible by size_divisor + vilt_compatible_image_size = (32, 352) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor( + text=input_str, images=image_input, size=[224, 224], crop_size=(224, 224), return_tensors="pt" + ) + + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], vilt_compatible_image_size[-1]) + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + # VILT resizes images to dims divisible by size_divisor + vilt_compatible_image_size = (32, 352) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": { + "size": {"height": 214, "width": 214}, + "crop_size": {"height": 214, "width": 214}, + }, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], vilt_compatible_image_size[-1]) + self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 76) + + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + # VILT resizes images to dims divisible by size_divisor + vilt_compatible_image_size = (32, 352) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": 
"pt"}, + "images_kwargs": { + "size": {"height": 214, "width": 214}, + "crop_size": {"height": 214, "width": 214}, + }, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], vilt_compatible_image_size[-1]) + self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 76) + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + # VILT resizes images to dims divisible by size_divisor + vilt_compatible_image_size = (32, 352) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + size={"height": 214, "width": 214}, + crop_size={"height": 214, "width": 214}, + padding="max_length", + max_length=76, + ) + + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], vilt_compatible_image_size[-1]) + self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 76) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + # VILT resizes images to dims divisible by size_divisor + vilt_compatible_image_size = (32, 352) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + size={"height": 214, "width": 214}, + crop_size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], vilt_compatible_image_size[-1]) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), len(inputs[self.text_data_arg_name][1])) diff --git a/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py index b6b1506f5e4d68..f5811d3d7f0b01 100644 --- a/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py +++ b/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py @@ -24,11 +24,14 @@ from transformers.models.wav2vec2_bert import Wav2Vec2BertProcessor from transformers.utils import FEATURE_EXTRACTOR_NAME +from ...test_processing_common import ProcessorTesterMixin from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list # Copied from tests.models.wav2vec2.test_processor_wav2vec2.Wav2Vec2ProcessorTest with Wav2Vec2FeatureExtractor->SeamlessM4TFeatureExtractor, Wav2Vec2Processor->Wav2Vec2BertProcessor -class Wav2Vec2BertProcessorTest(unittest.TestCase): +class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Wav2Vec2BertProcessor + def setUp(self): vocab = " | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ") vocab_tokens = dict(zip(vocab, range(len(vocab)))) @@ -56,6 +59,12 @@ def setUp(self): with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: 
fp.write(json.dumps(feature_extractor_map) + "\n") + def get_component(self, attribute, **kwargs): + assert attribute in self.processor_class.attributes + if attribute == "tokenizer": + return self.get_tokenizer(**kwargs) + return super().get_component(attribute, **kwargs) + def get_tokenizer(self, **kwargs_init): kwargs = self.add_kwargs_tokens_map.copy() kwargs.update(kwargs_init) diff --git a/tests/models/x_clip/test_processor_x_clip.py b/tests/models/x_clip/test_processor_x_clip.py new file mode 100644 index 00000000000000..5b34855a67252a --- /dev/null +++ b/tests/models/x_clip/test_processor_x_clip.py @@ -0,0 +1,48 @@ +import tempfile +import unittest + +from transformers import XCLIPProcessor +from transformers.testing_utils import require_torch, require_vision + +from ...test_processing_common import ProcessorTesterMixin + + +class XCLIPProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "microsoft/xclip-base-patch32" + processor_class = XCLIPProcessor + videos_data_arg_name = "pixel_values" + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) + + @require_torch + @require_vision + def test_image_processor_defaults_preserved_by_image_kwargs(self): + self.skipTest("XCLIP does not process images") + + @require_torch + @require_vision + def test_kwargs_overrides_default_image_processor_kwargs(self): + self.skipTest("XCLIP does not process images") + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + self.skipTest("XCLIP does not process images") + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + self.skipTest("XCLIP does not process images") + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + self.skipTest("XCLIP does not process images") + + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + self.skipTest("XCLIP does not process images") diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index a30c6363b9d7ff..dc2deb6bc2199a 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -49,6 +49,9 @@ @require_torch class ProcessorTesterMixin: processor_class = None + text_data_arg_name = "input_ids" + images_data_arg_name = "pixel_values" + videos_data_arg_name = "pixel_values_videos" def prepare_processor_dict(self): return {} @@ -88,6 +91,10 @@ def prepare_image_inputs(self): image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] return image_inputs + @require_vision + def prepare_video_inputs(self): + return np.random.randint(255, size=(1, 4, 3, 30, 400), dtype=np.uint8) + def test_processor_to_json_string(self): processor = self.get_processor() obj = json.loads(processor.to_json_string()) @@ -125,80 +132,110 @@ def skip_processor_without_typed_kwargs(self, processor): @require_vision @require_torch def test_tokenizer_defaults_preserved_by_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + if "tokenizer" not in self.processor_class.attributes: + self.skipTest(f"tokenizer attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + 
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") - processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + processor = self.processor_class(**processor_components) self.skip_processor_without_typed_kwargs(processor) input_str = "lower newer" - image_input = self.prepare_image_inputs() + image_input = self.prepare_image_inputs() if "image_processor" in self.processor_class.attributes else None inputs = processor(text=input_str, images=image_input, return_tensors="pt") - self.assertEqual(len(inputs["input_ids"][0]), 117) + self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 117) @require_torch @require_vision def test_image_processor_defaults_preserved_by_image_kwargs(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor", size=(234, 234)) - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", size=(234, 234), crop_size=(234, 234) + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") - processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + processor = self.processor_class(**processor_components) self.skip_processor_without_typed_kwargs(processor) input_str = "lower newer" image_input = self.prepare_image_inputs() - inputs = processor(text=input_str, images=image_input) - self.assertEqual(len(inputs["pixel_values"][0][0]), 234) + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], 234) + + @require_torch + @require_vision + def test_video_processor_defaults_preserved_by_kwargs(self): + if "video_processor" not in self.processor_class.attributes and ( + "videos" not in inspect.signature(self.processor_class.__call__).parameters + or inspect.signature(self.processor_class.__call__).parameters["videos"].annotation == inspect._empty + ): + self.skipTest(f"video_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", size=(234, 234), crop_size=(234, 234) + ) + if "video_processor" in self.processor_class.attributes: + processor_components["video_processor"] = self.get_component( + "video_processor", size=(234, 234), crop_size=(234, 234) + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + video_input = self.prepare_video_inputs() + + inputs = processor(text=input_str, videos=video_input, return_tensors="pt") + self.assertEqual(inputs[self.videos_data_arg_name].shape[-1], 234) @require_vision @require_torch def test_kwargs_overrides_default_tokenizer_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer", padding="longest") + if "tokenizer" not in 
self.processor_class.attributes: + self.skipTest(f"tokenizer attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest") - processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + processor = self.processor_class(**processor_components) self.skip_processor_without_typed_kwargs(processor) input_str = "lower newer" - image_input = self.prepare_image_inputs() + image_input = self.prepare_image_inputs() if "image_processor" in self.processor_class.attributes else None inputs = processor( text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length" ) - self.assertEqual(len(inputs["input_ids"][0]), 112) + self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 112) @require_torch @require_vision def test_kwargs_overrides_default_image_processor_kwargs(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor", size=(234, 234)) - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component("image_processor", size=(234, 234)) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") - processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + processor = self.processor_class(**processor_components) self.skip_processor_without_typed_kwargs(processor) input_str = "lower newer" image_input = self.prepare_image_inputs() - inputs = processor(text=input_str, images=image_input, size=[224, 224]) - self.assertEqual(len(inputs["pixel_values"][0][0]), 224) + inputs = processor( + text=input_str, images=image_input, size=[224, 224], crop_size=(224, 224), return_tensors="pt" + ) + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], 224) @require_torch @require_vision def test_unstructured_kwargs(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - - processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) self.skip_processor_without_typed_kwargs(processor) input_str = "lower newer" @@ -208,22 +245,21 @@ def test_unstructured_kwargs(self): images=image_input, return_tensors="pt", size={"height": 214, "width": 214}, + crop_size={"height": 214, "width": 214}, padding="max_length", max_length=76, ) - self.assertEqual(inputs["pixel_values"].shape[2], 214) - self.assertEqual(len(inputs["input_ids"][0]), 76) + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], 214) + self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 76) @require_torch @require_vision def test_unstructured_kwargs_batched(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - - processor = self.processor_class(tokenizer=tokenizer, 
image_processor=image_processor) + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) self.skip_processor_without_typed_kwargs(processor) input_str = ["lower newer", "upper older longer string"] @@ -233,23 +269,21 @@ def test_unstructured_kwargs_batched(self): images=image_input, return_tensors="pt", size={"height": 214, "width": 214}, + crop_size={"height": 214, "width": 214}, padding="longest", max_length=76, ) - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 6) + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], 214) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), len(inputs[self.text_data_arg_name][1])) @require_torch @require_vision def test_doubly_passed_kwargs(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - - processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) self.skip_processor_without_typed_kwargs(processor) input_str = ["lower newer"] @@ -260,6 +294,8 @@ def test_doubly_passed_kwargs(self): images=image_input, images_kwargs={"size": {"height": 222, "width": 222}}, size={"height": 214, "width": 214}, + crop_size={"height": 214, "width": 214}, + return_tensors="pt", ) @require_torch @@ -267,10 +303,8 @@ def test_doubly_passed_kwargs(self): def test_structured_kwargs_nested(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - - processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) self.skip_processor_without_typed_kwargs(processor) input_str = "lower newer" @@ -279,42 +313,44 @@ def test_structured_kwargs_nested(self): # Define the kwargs for each modality all_kwargs = { "common_kwargs": {"return_tensors": "pt"}, - "images_kwargs": {"size": {"height": 214, "width": 214}}, + "images_kwargs": { + "size": {"height": 214, "width": 214}, + "crop_size": {"height": 214, "width": 214}, + }, "text_kwargs": {"padding": "max_length", "max_length": 76}, } inputs = processor(text=input_str, images=image_input, **all_kwargs) self.skip_processor_without_typed_kwargs(processor) - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 76) + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], 214) + self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 76) @require_torch @require_vision def test_structured_kwargs_nested_from_dict(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") - - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - - processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) 
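        # prepare_components() builds one component per entry in
        # self.processor_class.attributes, so constructing the processor from
        # **processor_components also accommodates processors with additional
        # components when a subclass overrides it (e.g. the qformer_tokenizer set up
        # in the InstructBlipVideo test above), instead of hard-coding tokenizer and
        # image_processor as the old tests did.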
self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" image_input = self.prepare_image_inputs() # Define the kwargs for each modality all_kwargs = { "common_kwargs": {"return_tensors": "pt"}, - "images_kwargs": {"size": {"height": 214, "width": 214}}, + "images_kwargs": { + "size": {"height": 214, "width": 214}, + "crop_size": {"height": 214, "width": 214}, + }, "text_kwargs": {"padding": "max_length", "max_length": 76}, } inputs = processor(text=input_str, images=image_input, **all_kwargs) - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 76) + self.assertEqual(inputs[self.images_data_arg_name].shape[-1], 214) + self.assertEqual(inputs[self.text_data_arg_name].shape[-1], 76) class MyProcessor(ProcessorMixin):
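As a concrete illustration of the call pattern these refactored processors and tests exercise, here is a minimal usage sketch: per-modality options can be passed either flat or grouped under text_kwargs / images_kwargs / common_kwargs, and the call returns a BatchFeature. The sketch assumes the BAAI/AltCLIP checkpoint used in the AltCLIP test above; the prompt string, the dummy image, and the specific option values are illustrative placeholders rather than values taken from the patch.

import numpy as np
from PIL import Image

from transformers import AltCLIPProcessor

# Checkpoint name taken from tests/models/altclip/test_processor_altclip.py.
processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")

# A dummy RGB image stands in for real input data.
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

# Flat ("unstructured") kwargs: _merge_kwargs routes each option to the matching
# component based on the processor's ProcessingKwargs definition.
batch = processor(
    text="a photo of a cat",
    images=image,
    padding="max_length",
    max_length=64,
    return_tensors="pt",
)

# Structured kwargs: the same options grouped per modality, mirroring the
# test_structured_kwargs_nested tests above.
batch = processor(
    text="a photo of a cat",
    images=image,
    text_kwargs={"padding": "max_length", "max_length": 64},
    images_kwargs={"size": {"height": 224, "width": 224}, "crop_size": {"height": 224, "width": 224}},
    common_kwargs={"return_tensors": "pt"},
)

# Either form returns a BatchFeature holding input_ids, attention_mask and
# pixel_values.
print(sorted(batch.keys()))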