diff --git a/src/transformers/models/clipseg/processing_clipseg.py b/src/transformers/models/clipseg/processing_clipseg.py
index f8eaca82334a22..121c1e2aa95442 100644
--- a/src/transformers/models/clipseg/processing_clipseg.py
+++ b/src/transformers/models/clipseg/processing_clipseg.py
@@ -16,10 +16,28 @@
 Image/Text processor class for CLIPSeg
 """
 
+import sys
 import warnings
+from typing import List, Optional, Union
 
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
+
+class CLIPSegImagesKwargs(ImagesKwargs, total=False):
+    visual_prompt: Optional[ImageInput]
+
+
+class CLIPSegProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: CLIPSegImagesKwargs
+    _defaults = {}
 
 
 class CLIPSegProcessor(ProcessorMixin):
@@ -58,7 +76,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
 
         super().__init__(image_processor, tokenizer)
 
-    def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
+    def __call__(
+        self,
+        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        images: Optional[ImageInput] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[CLIPSegProcessorKwargs],
+    ):
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
         and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
@@ -79,14 +104,6 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=No
                 NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape
                 (C, H, W), where C is a number of channels, H and W are image height and width.
 
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
         Returns:
             [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
 
@@ -96,6 +113,29 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=No
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
+
+        output_kwargs = self._merge_kwargs(
+            CLIPSegProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        if output_kwargs["text_kwargs"].get("visual_prompt") is not None and audio is not None:
+            raise ValueError(
+                "You cannot provide `visual_prompt` as a positional argument and as a keyword argument at the same time. "
+                "Please provide it only as a keyword argument (i.e. `visual_prompt=...`)."
+            )
+        if "visual_prompt" not in output_kwargs["text_kwargs"]:
+            warnings.warn(
+                "No `visual_prompt` kwarg was detected. The use of `visual_prompt` as an argument without specifying it explicitly as `visual_prompt=` will be deprecated in future versions."
+ ) + # For backwards compatibility, we reuse `audio` as `visual_prompt` in case + # downstream users passed it as a positional argument + if audio is not None: + output_kwargs["text_kwargs"]["visual_prompt"] = audio + + visual_prompt = output_kwargs["text_kwargs"].pop("visual_prompt", None) + if text is None and visual_prompt is None and images is None: raise ValueError("You have to specify either text, visual prompt or images.") @@ -103,13 +143,13 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=No raise ValueError("You have to specify exactly one type of prompt. Either text or visual prompt.") if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) if visual_prompt is not None: - prompt_features = self.image_processor(visual_prompt, return_tensors=return_tensors, **kwargs) + prompt_features = self.image_processor(visual_prompt, **output_kwargs["images_kwargs"]) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) if visual_prompt is not None and images is not None: encoding = { @@ -128,7 +168,9 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=No } return encoding else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + return BatchEncoding( + data=dict(**image_features), tensor_type=output_kwargs["common_kwargs"]["return_tensors"] + ) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py index f56f8186b07d73..1ec03dd661420b 100644 --- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py @@ -17,26 +17,48 @@ """ import os +import sys from typing import List, Optional, Union from ...image_processing_utils import BatchFeature from ...image_utils import VideoInput -from ...processing_utils import ProcessorMixin +from ...processing_utils import ProcessingKwargs, ProcessorMixin from ...tokenization_utils_base import ( AddedToken, BatchEncoding, - PaddingStrategy, PreTokenizedInput, TextInput, - TruncationStrategy, ) -from ...utils import TensorType, logging +from ...utils import logging from ..auto import AutoTokenizer +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + logger = logging.get_logger(__name__) +class InstructBlipVideoProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "truncation": None, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + } + + class InstructBlipVideoProcessor(ProcessorMixin): r""" Constructs an InstructBLIPVideo processor which wraps a InstructBLIP image processor and a LLaMa/T5 tokenizer into a single @@ -71,23 +93,11 @@ def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, num_query def __call__( self, - images: VideoInput = None, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = True, - 
padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_token_type_ids: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, + images: Optional[VideoInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[InstructBlipVideoProcessorKwargs], ) -> BatchFeature: """ This method uses [`InstructBlipVideoImageProcessor.__call__`] method to prepare image(s) or video(s) for the model, and @@ -95,6 +105,12 @@ def __call__( Please refer to the docstring of the above two methods for more information. """ + output_kwargs = self._merge_kwargs( + InstructBlipVideoProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + encoding = BatchFeature() if text is not None: @@ -105,21 +121,10 @@ def __call__( _text_encoding = self.tokenizer( text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=None, # required to concatenate below - **kwargs, + **{ + **output_kwargs["text_kwargs"], + "return_tensors": None, # required to concatenate below + }, ) # if we know how many query tokens, expand text inside processor. 
We need this hacky manipulation @@ -145,31 +150,14 @@ def __call__( ) # cast to desired return tensors type after concatenating - text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors) + text_encoding = BatchEncoding(text_encoding, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]) encoding.update(text_encoding) - qformer_text_encoding = self.qformer_tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) + qformer_text_encoding = self.qformer_tokenizer(text=text, **output_kwargs["text_kwargs"]) encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids") encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask") if images is not None: - image_encoding = self.image_processor(images, return_tensors=return_tensors) + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) encoding.update(image_encoding) return encoding diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index efbb193ba62a9f..e693ce265ef1e6 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -16,13 +16,20 @@ Processor class for LLaVa-NeXT-Video. 
""" +import sys from typing import TYPE_CHECKING, List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType, logging +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack if TYPE_CHECKING: @@ -31,6 +38,17 @@ logger = logging.get_logger(__name__) +class LlavaNextVideoProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + class LlavaNextVideoProcessor(ProcessorMixin): r""" Constructs a LLaVa-NeXT-Video processor which wraps a LLaVa-NeXT image processor, LLaVa-NeXT-Video video processor and @@ -88,12 +106,10 @@ def __init__( def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], - images: ImageInput = None, - videos: VideoInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: int = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + images: Optional[ImageInput] = None, + videos: Optional[VideoInput] = None, + audio=None, + **kwargs: Unpack[LlavaNextVideoProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -115,26 +131,6 @@ def __call__( videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -145,13 +141,19 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
""" + output_kwargs = self._merge_kwargs( + LlavaNextVideoProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: - image_inputs = self.image_processor(images, return_tensors=return_tensors) + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) else: image_inputs = {} if videos is not None: - videos_inputs = self.video_processor(videos, return_tensors=return_tensors) + videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"]) else: videos_inputs = {} @@ -203,14 +205,7 @@ def __call__( sample = sample.replace(self.video_token, self.video_token * num_video_tokens) prompt_strings.append(sample) - text_inputs = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) - print(text_inputs.keys()) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) diff --git a/src/transformers/models/owlv2/processing_owlv2.py b/src/transformers/models/owlv2/processing_owlv2.py index 8b580ca5026618..0f605ac3f6adc3 100644 --- a/src/transformers/models/owlv2/processing_owlv2.py +++ b/src/transformers/models/owlv2/processing_owlv2.py @@ -16,15 +16,40 @@ Image/Text processor class for OWLv2 """ -from typing import List +import sys +import warnings +from typing import List, Optional, Union import numpy as np -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput from ...utils import is_flax_available, is_tf_available, is_torch_available +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class Owlv2ImagesKwargs(ImagesKwargs, total=False): + query_images: Optional[ImageInput] + + +class Owlv2ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Owlv2ImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "common_kwargs": { + "return_tensors": "np", + }, + } + + class Owlv2Processor(ProcessorMixin): r""" Constructs an Owlv2 processor which wraps [`Owlv2ImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into @@ -46,7 +71,14 @@ def __init__(self, image_processor, tokenizer, **kwargs): super().__init__(image_processor, tokenizer) # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.__call__ with OWLViT->OWLv2 - def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + audio=None, + videos=None, + **kwargs: Unpack[Owlv2ProcessorKwargs], + ): """ Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode: @@ -67,12 +99,7 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt The query image to be prepared, one query image is expected per target image to be queried. Each image can be a PIL image, NumPy array or PyTorch tensor. 
                In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors of a particular framework. Acceptable values are:
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
         Returns:
             [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
@@ -81,6 +108,28 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
+        output_kwargs = self._merge_kwargs(
+            Owlv2ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        if output_kwargs["text_kwargs"].get("query_images") is not None and audio is not None:
+            raise ValueError(
+                "You cannot provide `query_images` as a positional argument and as a keyword argument at the same time. "
+                "Please provide it only as a keyword argument (i.e. `query_images=...`)."
+            )
+        if "query_images" not in output_kwargs["text_kwargs"]:
+            warnings.warn(
+                "No `query_images` kwarg was detected. The use of `query_images` as an argument without specifying it explicitly as `query_images=` will be deprecated in future versions."
+            )
+        # For backwards compatibility, we reuse `audio` as `query_images` in case
+        # downstream users passed it as a positional argument
+        if audio is not None:
+            output_kwargs["text_kwargs"]["query_images"] = audio
+
+        query_images = output_kwargs["text_kwargs"].pop("query_images", None)
+        return_tensors = output_kwargs["common_kwargs"]["return_tensors"]
 
        if text is None and query_images is None and images is None:
            raise ValueError(
@@ -89,7 +138,7 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
 
        if text is not None:
            if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
-                encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
+                encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])]
 
            elif isinstance(text, List) and isinstance(text[0], List):
                encodings = []
@@ -102,7 +151,7 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
                    if len(t) != max_num_queries:
                        t = t + [" "] * (max_num_queries - len(t))
 
-                    encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
+                    encoding = self.tokenizer(t, **output_kwargs["text_kwargs"])
                    encodings.append(encoding)
            else:
                raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
@@ -138,13 +187,11 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
 
        if query_images is not None:
            encoding = BatchEncoding()
-            query_pixel_values = self.image_processor(
-                query_images, return_tensors=return_tensors, **kwargs
-            ).pixel_values
+            query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values
            encoding["query_pixel_values"] = query_pixel_values
 
        if images is not None:
-            image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
 
        if text is not None and images is not
None: encoding["pixel_values"] = image_features.pixel_values diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 2c7d490104bdfc..7a512f2d1d7e02 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -16,16 +16,40 @@ Image/Text processor class for OWL-ViT """ +import sys import warnings -from typing import List +from typing import List, Optional, Union import numpy as np -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput from ...utils import is_flax_available, is_tf_available, is_torch_available +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class OwlViTImagesKwargs(ImagesKwargs, total=False): + query_images: Optional[ImageInput] + + +class OwlViTProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: OwlViTImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "common_kwargs": { + "return_tensors": "np", + }, + } + + class OwlViTProcessor(ProcessorMixin): r""" Constructs an OWL-ViT processor which wraps [`OwlViTImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] @@ -61,7 +85,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): super().__init__(image_processor, tokenizer) - def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs): + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + audio=None, + videos=None, + **kwargs: Unpack[OwlViTProcessorKwargs], + ): """ Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode: @@ -82,12 +113,7 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt The query image to be prepared, one query image is expected per target image to be queried. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. @@ -97,6 +123,29 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
""" + output_kwargs = self._merge_kwargs( + OwlViTProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if output_kwargs["text_kwargs"].get("query_images") is not None and audio is not None: + raise ValueError( + "You cannot provide `query_images` as a positional argument and as a keyword argument at the same time." + "Please provide it only as a keyword argument (i.e. `query_images=...`)." + ) + if "query_images" not in output_kwargs["text_kwargs"]: + warnings.warn( + "No `query_images` kwarg was detected. The use of `query_images` as an argument without specifying it explicitely as `query_images=` will be deprecated in future versions." + ) + # For backwards compatibility, we reuse `audio` as `query_images` in case + # downstream users passed it as a positional argument + if audio is not None: + output_kwargs["text_kwargs"]["query_images"] = audio + + query_images = output_kwargs["text_kwargs"].pop("query_images", None) + return_tensors = output_kwargs["common_kwargs"]["return_tensors"] + if text is None and query_images is None and images is None: raise ValueError( "You have to specify at least one text or query image or image. All three cannot be none." @@ -104,7 +153,7 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt if text is not None: if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)): - encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)] + encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])] elif isinstance(text, List) and isinstance(text[0], List): encodings = [] @@ -117,7 +166,7 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt if len(t) != max_num_queries: t = t + [" "] * (max_num_queries - len(t)) - encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs) + encoding = self.tokenizer(t, **output_kwargs["text_kwargs"]) encodings.append(encoding) else: raise TypeError("Input text should be a string, a list of strings or a nested list of strings") @@ -153,13 +202,11 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt if query_images is not None: encoding = BatchEncoding() - query_pixel_values = self.image_processor( - query_images, return_tensors=return_tensors, **kwargs - ).pixel_values + query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values encoding["query_pixel_values"] = query_pixel_values if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) if text is not None and images is not None: encoding["pixel_values"] = image_features.pixel_values diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index a06913d7acf760..35eab8bdc14060 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -16,18 +16,36 @@ Processor class for VideoLlava. 
""" +import sys from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType, logging +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack logger = logging.get_logger(__name__) +class VideoLlavaProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + class VideoLlavaProcessor(ProcessorMixin): r""" Constructs a VideoLlava processor which wraps a VideoLlava image processor and a Llava tokenizer into a single processor. @@ -77,13 +95,11 @@ def __init__( def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - images: ImageInput = None, - videos: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length=None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + videos: Optional[ImageInput] = None, + audio=None, + **kwargs: Unpack[VideoLlavaProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -105,26 +121,6 @@ def __call__( Video frames to preprocess. Expects a single or batch of video frames in NumPy array or PyTorch tensor. Each video should be of shape (T, C, H, W), where T is number of frames, C is number of channels, H and W are image height and width. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -135,9 +131,15 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
""" + output_kwargs = self._merge_kwargs( + VideoLlavaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} if images is not None or videos is not None: - encoded_images = self.image_processor(images=images, videos=videos, return_tensors=return_tensors) + encoded_images = self.image_processor(images=images, videos=videos, **output_kwargs["images_kwargs"]) data.update(encoded_images) if isinstance(text, str): @@ -174,13 +176,7 @@ def __call__( sample = sample.replace(self.video_token, self.video_token * num_video_tokens) prompt_strings.append(sample) - text_inputs = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) data.update(text_inputs) return BatchFeature(data=data)