diff --git a/docs/source/en/model_doc/align.md b/docs/source/en/model_doc/align.md index 5e41dac6024a20..0d34d95a798109 100644 --- a/docs/source/en/model_doc/align.md +++ b/docs/source/en/model_doc/align.md @@ -46,7 +46,7 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) candidate_labels = ["an image of a cat", "an image of a dog"] -inputs = processor(text=candidate_labels, images=image, return_tensors="pt") +inputs = processor(images=image, text=candidate_labels, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) diff --git a/docs/source/en/model_doc/fuyu.md b/docs/source/en/model_doc/fuyu.md index a2e7be90aaf82a..bd55737da58ff8 100644 --- a/docs/source/en/model_doc/fuyu.md +++ b/docs/source/en/model_doc/fuyu.md @@ -18,16 +18,16 @@ rendered properly in your Markdown viewer. ## Overview -The Fuyu model was created by [ADEPT](https://www.adept.ai/blog/fuyu-8b), and authored by Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar. +The Fuyu model was created by [ADEPT](https://www.adept.ai/blog/fuyu-8b), and authored by Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar. -The authors introduced Fuyu-8B, a decoder-only multimodal model based on the classic transformers architecture, with query and key normalization. A linear encoder is added to create multimodal embeddings from image inputs. +The authors introduced Fuyu-8B, a decoder-only multimodal model based on the classic transformers architecture, with query and key normalization. A linear encoder is added to create multimodal embeddings from image inputs. By treating image tokens like text tokens and using a special image-newline character, the model knows when an image line ends. Image positional embeddings are removed. This avoids the need for different training phases for various image resolutions. With 8 billion parameters and licensed under CC-BY-NC, Fuyu-8B is notable for its ability to handle both text and images, its impressive context size of 16K, and its overall performance. The `Fuyu` models were trained using `bfloat16`, but the original inference uses `float16` The checkpoints uploaded on the hub use `torch_dtype = 'float16'` which will be -used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`. +used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`. The `dtype` of the online weights is mostly irrelevant, unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online) then it will be cast to the default `dtype` of `torch` (becomes `torch.float32`). Users should specify the `torch_dtype` they want, and if they don't it will be `torch.float32`.
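As a rough illustration of the dtype behaviour the Fuyu docs describe above (a sketch only — it assumes the `adept/fuyu-8b` checkpoint used elsewhere in this diff, whose config stores `float16` as stated):

```py
import torch
from transformers import AutoModelForCausalLM

# Default: weights are downloaded in the checkpoint dtype, then cast to torch.float32.
model = AutoModelForCausalLM.from_pretrained("adept/fuyu-8b")
print(model.dtype)  # torch.float32

# torch_dtype="auto" keeps the dtype recorded in the checkpoint config (float16 here).
model = AutoModelForCausalLM.from_pretrained("adept/fuyu-8b", torch_dtype="auto")
print(model.dtype)  # torch.float16

# Or request the training dtype explicitly.
model = AutoModelForCausalLM.from_pretrained("adept/fuyu-8b", torch_dtype=torch.bfloat16)
print(model.dtype)  # torch.bfloat16
```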
@@ -56,7 +56,7 @@ tar -xvf 8b_base_model_release.tar ``` Then, model can be loaded via: -```py +```py from transformers import FuyuConfig, FuyuForCausalLM model_config = FuyuConfig() model = FuyuForCausalLM(model_config).from_pretrained('/output/path') @@ -81,7 +81,7 @@ text_prompt = "Generate a coco-style caption.\\n" bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png" bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content)) -inputs_to_model = processor(text=text_prompt, images=bus_image_pil) +inputs_to_model = processor(images=bus_image_pil, text=text_prompt) ``` @@ -90,7 +90,7 @@ This model was contributed by [Molbap](https://huggingface.co/Molbap). The original code can be found [here](https://github.com/persimmon-ai-labs/adept-inference). - Fuyu uses a `sentencepiece` based tokenizer, with a `Unigram` model. It supports bytefallback, which is only available in `tokenizers==0.14.0` for the fast tokenizer. -The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. +The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. - The authors suggest to use the following prompt for image captioning: `f"Generate a coco-style caption.\\n"` diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index d0558be76467a2..1faeea67cf8743 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -133,7 +133,7 @@ import requests processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") -model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True) +model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True) model.to("cuda:0") # prepare image and text prompt, using the appropriate prompt template @@ -150,7 +150,7 @@ conversation = [ }, ] prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) -inputs = processor(prompt, image, return_tensors="pt").to("cuda:0") +inputs = processor(image, prompt, return_tensors="pt").to("cuda:0") # autoregressively complete prompt output = model.generate(**inputs, max_new_tokens=100) @@ -222,7 +222,7 @@ prompts = [prompt_1, prompt_2] # We can simply feed images in the order they have to be used in the text prompt # Each "" token uses one image leaving the next for the subsequent "" tokens -inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device) +inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device) # Generate generate_ids = model.generate(**inputs, max_new_tokens=30) @@ -256,8 +256,8 @@ First make sure to install flash-attn. 
Refer to the [original repository of Flas from transformers import LlavaNextForConditionalGeneration model = LlavaNextForConditionalGeneration.from_pretrained( - model_id, - torch_dtype=torch.float16, + model_id, + torch_dtype=torch.float16, low_cpu_mem_usage=True, use_flash_attention_2=True ).to(0) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index dea035618a3341..f834aecf6932a3 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1575,7 +1575,7 @@ def forward( >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor( - ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True ... ) >>> outputs = model(**inputs) diff --git a/src/transformers/models/align/processing_align.py b/src/transformers/models/align/processing_align.py index 7cfe14e52b44f9..792f614b10bea0 100644 --- a/src/transformers/models/align/processing_align.py +++ b/src/transformers/models/align/processing_align.py @@ -19,11 +19,7 @@ from typing import List, Union from ...image_utils import ImageInput -from ...processing_utils import ( - ProcessingKwargs, - ProcessorMixin, - Unpack, -) +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput @@ -76,8 +72,8 @@ def __init__(self, image_processor, tokenizer): def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, audio=None, videos=None, **kwargs: Unpack[AlignProcessorKwargs], @@ -90,13 +86,13 @@ def __call__( to the doctsring of the above two methods for more information. Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - `'tf'`: Return TensorFlow `tf.constant` objects. 
@@ -114,6 +110,9 @@ def __call__( """ if text is None and images is None: raise ValueError("You must specify either text or images.") + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + output_kwargs = self._merge_kwargs( AlignProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 089313b03b7b60..88a7e1ff41c4d3 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -264,7 +264,7 @@ def forward( >>> image = Image.open(requests.get(url, stream=True).raw) >>> prompt = "Generate a coco-style caption.\n" - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") >>> outputs = model(**inputs) >>> generated_ids = model.generate(**inputs, max_new_tokens=7) diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index 6b542ba3378e67..ff7d2c547dc44c 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -21,9 +21,10 @@ import numpy as np -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy -from ...utils import TensorType, is_torch_available, logging, requires_backends +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import is_torch_available, logging, requires_backends if is_torch_available(): @@ -49,6 +50,24 @@ BEGINNING_OF_ANSWER_STRING = "<0x04>" # +class FuyuProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_attention_mask": True, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } + + def full_unpacked_stream_to_tensor( all_bi_tokens_to_place: List[int], full_unpacked_stream: List["torch.Tensor"], @@ -452,23 +471,11 @@ def get_sample_encoding( def __call__( self, - text=None, - images=None, - add_special_tokens: bool = True, - return_attention_mask: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_token_type_ids: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, + images: ImageInput = None, + text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[FuyuProcessorKwargs], ) -> "FuyuBatchFeature": """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -478,13 +485,13 @@ def __call__( of the above two methods for more information. 
Args: + images (`PIL.Image.Image`, `List[PIL.Image.Image]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `List[PIL.Image.Image]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. Returns: [`FuyuBatchEncoding`]: A [`FuyuBatchEncoding`] with the following fields: @@ -498,31 +505,24 @@ def __call__( requires_backends(self, ["torch"]) # --- Check input validity --- - if not return_attention_mask: - raise ValueError("`return_attention_mask=False` is not supported for this model.") if text is None and images is None: raise ValueError("You have to specify either text or images. Both cannot be None.") + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + FuyuProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True): + raise ValueError("`return_attention_mask=False` is not supported for this model.") + if text is not None and images is None: logger.warning("You are processing a text with no associated image. Make sure it is intended.") self.current_processor = self.tokenizer - text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) + text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) return text_encoding if text is None and images is not None: @@ -537,7 +537,8 @@ def __call__( # --- Preprocess images using self.image_processor --- # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors - image_encoding = self.image_processor.preprocess(images, return_tensors="pt") + output_kwargs["images_kwargs"]["return_tensors"] = "pt" + image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"]) batch_images = image_encoding["images"] image_unpadded_heights = image_encoding["image_unpadded_heights"] image_unpadded_widths = image_encoding["image_unpadded_widths"] @@ -568,7 +569,7 @@ def __call__( ) all_encodings.append(sample_encoding) batch_encoding = self._left_pad_inputs_with_attention_mask( - model_inputs=all_encodings, return_attention_mask=return_attention_mask + model_inputs=all_encodings, return_attention_mask=True ) return FuyuBatchFeature(data=batch_encoding) diff --git a/src/transformers/models/grounding_dino/processing_grounding_dino.py b/src/transformers/models/grounding_dino/processing_grounding_dino.py index 
00c183338be056..2b576992851884 100644 --- a/src/transformers/models/grounding_dino/processing_grounding_dino.py +++ b/src/transformers/models/grounding_dino/processing_grounding_dino.py @@ -17,20 +17,12 @@ """ import pathlib -import sys from typing import Dict, List, Optional, Tuple, Union from ...image_processing_utils import BatchFeature from ...image_transforms import center_to_corners_format from ...image_utils import AnnotationFormat, ImageInput -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin - - -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput from ...utils import TensorType, is_torch_available diff --git a/src/transformers/models/instructblip/processing_instructblip.py b/src/transformers/models/instructblip/processing_instructblip.py index e3251395a78153..f6d35c1e6f7259 100644 --- a/src/transformers/models/instructblip/processing_instructblip.py +++ b/src/transformers/models/instructblip/processing_instructblip.py @@ -17,26 +17,41 @@ """ import os -from typing import List, Optional, Union +from typing import List, Union from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import ( AddedToken, BatchEncoding, - PaddingStrategy, PreTokenizedInput, TextInput, - TruncationStrategy, ) -from ...utils import TensorType, logging +from ...utils import logging from ..auto import AutoTokenizer logger = logging.get_logger(__name__) +class InstructBlipProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } + + class InstructBlipProcessor(ProcessorMixin): r""" Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single @@ -72,31 +87,33 @@ def __call__( self, images: ImageInput = None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_token_type_ids: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, + audio=None, + videos=None, + **kwargs: Unpack[InstructBlipProcessorKwargs], ) -> BatchFeature: """ This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and [`BertTokenizerFast.__call__`] to prepare text for the model. Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`): + The image or batch of images to be prepared. 
Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). """ if images is None and text is None: raise ValueError("You have to specify at least images or text.") + output_kwargs = self._merge_kwargs( + InstructBlipProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + encoding = BatchFeature() if text is not None: @@ -105,24 +122,7 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - _text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=None, # needed to concatenate below - **kwargs, - ) + _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) # if we know how many query tokens, expand text inside processor. We need this hacky manipulation # because BLIP expects image tokens to be at the beginning even before BOS token @@ -145,31 +145,17 @@ def __call__( ) # cast to desired return tensors type after concatenating - text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors) - encoding.update(text_encoding) - qformer_text_encoding = self.qformer_tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, + text_encoding = BatchEncoding( + text_encoding, tensor_type=output_kwargs["common_kwargs"].get("return_tensors") ) + + encoding.update(text_encoding) + qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"]) encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids") encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask") if images is not None: - image_encoding = self.image_processor(images, return_tensors=return_tensors) + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) encoding.update(image_encoding) return encoding diff --git a/src/transformers/models/kosmos2/processing_kosmos2.py b/src/transformers/models/kosmos2/processing_kosmos2.py index 7f54ac3b44bd26..76108789718b41 100644 --- a/src/transformers/models/kosmos2/processing_kosmos2.py +++ b/src/transformers/models/kosmos2/processing_kosmos2.py @@ -21,10 +21,9 @@ from 
...image_processing_utils import BatchFeature from ...image_utils import ImageInput, is_batched -from ...processing_utils import ProcessorMixin +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack from ...tokenization_utils import AddedToken -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy -from ...utils import TensorType +from ...tokenization_utils_base import BatchEncoding, TextInput BboxInput = Union[ @@ -35,6 +34,37 @@ ] +class Kosmos2ImagesKwargs(ImagesKwargs, total=False): + bboxes: Optional[List[float]] + num_image_tokens: Optional[int] + first_image_token_id: Optional[int] + + +class Kosmos2TextKwargs(TextKwargs, total=False): + add_eos_token: Optional[bool] + + +class Kosmos2ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: Kosmos2TextKwargs + images_kwargs: Kosmos2ImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "verbose": True, + "add_eos_token": False, + }, + "images_kwargs": { + "num_image_tokens": 64, + }, + } + + class Kosmos2Processor(ProcessorMixin): r""" Constructs an KOSMOS-2 processor which wraps a KOSMOS-2 image processor and a KOSMOS-2 tokenizer into a single @@ -56,7 +86,7 @@ class Kosmos2Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] valid_kwargs = ["num_patch_index_tokens"] image_processor_class = "CLIPImageProcessor" - tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast") + tokenizer_class = "AutoTokenizer" def __init__(self, image_processor, tokenizer, num_patch_index_tokens=1024, *kwargs): tokenizer.return_token_type_ids = False @@ -107,20 +137,9 @@ def __call__( self, images: ImageInput = None, text: Union[TextInput, List[TextInput]] = None, - bboxes: BboxInput = None, - num_image_tokens: Optional[int] = 64, - first_image_token_id: Optional[int] = None, - add_special_tokens: bool = True, - add_eos_token: bool = False, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, + audio=None, + videos=None, + **kwargs: Unpack[Kosmos2ProcessorKwargs], ) -> BatchFeature: """ This method uses [`CLIPImageProcessor.__call__`] method to prepare image(s) for the model, and @@ -145,10 +164,25 @@ def __call__( if images is None and text is None: raise ValueError("You have to specify either images or text.") + output_kwargs = self._merge_kwargs( + Kosmos2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + bboxes = output_kwargs["images_kwargs"].pop("bboxes", None) + num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens", 64) + first_image_token_id = output_kwargs["images_kwargs"].pop("first_image_token_id", None) + add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False) + + add_special_tokens = output_kwargs["text_kwargs"]["add_special_tokens"] + padding = output_kwargs["text_kwargs"]["padding"] + return_tensors = output_kwargs["text_kwargs"].setdefault("return_tensors", None) + encoding = BatchFeature() if images is not None: - 
image_encoding = self.image_processor(images, return_tensors=return_tensors) + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) encoding.update(image_encoding) if text is not None: @@ -159,21 +193,18 @@ def __call__( text = f"{self.tokenizer.bos_token}{text}" elif isinstance(text, list): text = [f"{self.tokenizer.bos_token}{s}" for s in text] - - text_encoding = self.tokenizer( - text=text, - add_special_tokens=(add_special_tokens and add_eos_token), - padding=padding and images is None, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of if images is None else pad_to_multiple_of, - return_attention_mask=return_attention_mask, - verbose=verbose, - return_tensors=return_tensors if images is None else None, - **kwargs, + output_kwargs["text_kwargs"]["add_special_tokens"] = ( + output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token ) + output_kwargs["text_kwargs"]["padding"] = padding if images is None else False + output_kwargs["text_kwargs"]["return_tensors"] = return_tensors if images is None else None + text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) encoding.update(text_encoding) + output_kwargs["text_kwargs"]["add_special_tokens"] = add_special_tokens + output_kwargs["text_kwargs"]["padding"] = padding + output_kwargs["text_kwargs"]["return_tensors"] = return_tensors + if text is not None and images is not None: # Use the id of the first token after if first_image_token_id is None: @@ -218,18 +249,12 @@ def __call__( ) _, min_len_not_padded = sorted_length[0] idx, _ = sorted_length[-1] - - text_encoding = self.tokenizer( - text=[text[idx]], - add_special_tokens=(add_special_tokens and add_eos_token), - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - return_tensors=None, - **kwargs, + output_kwargs["text_kwargs"]["add_special_tokens"] = ( + output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token ) + output_kwargs["text_kwargs"]["return_tensors"] = None + + text_encoding = self.tokenizer(text=[text[idx]], **output_kwargs["text_kwargs"]) max_len_padded = len(text_encoding.input_ids[0]) if min_len_not_padded != max_len_padded: diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index 579e6d44c1435b..41118599ec93b7 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -715,7 +715,9 @@ def preprocess( image_patches = self.get_image_patches( image, image_grid_pinpoints, - size=(size["shortest_edge"], size["shortest_edge"]), + size=(size["shortest_edge"], size["shortest_edge"]) + if "shortest_edge" in size + else (min(size["height"], size["width"]), min(size["height"], size["width"])), patch_size=crop_size["height"], resample=resample, data_format=input_data_format, diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index bf76921090b244..d89d3f3ecbfba9 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -762,7 +762,7 @@ def forward( >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + >>> inputs = 
processor(images=image, text=prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(**inputs, max_length=30) diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index 2a2df041283ed3..ce11be6d6309a8 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -16,19 +16,30 @@ Processor class for LLaVa-NeXT. """ -from typing import List, Optional, Union +from typing import List, Union from ...feature_extraction_utils import BatchFeature from ...image_processing_utils import select_best_resolution from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType, logging +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging logger = logging.get_logger(__name__) +class LlavaNextProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + "do_pad": True, + }, + } + + class LlavaNextProcessor(ProcessorMixin): r""" Constructs a LLaVa-NeXT processor which wraps a LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor. @@ -74,13 +85,11 @@ def __init__( def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], images: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - do_pad: Optional[bool] = True, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[LlavaNextProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -90,36 +99,13 @@ def __call__( of the above two methods for more information. Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. 
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - do_pad (`bool`, *optional*, defaults to self.do_pad): - Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch - and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros. - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -130,8 +116,18 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + if images is None and text is None: + raise ValueError("You have to specify at least images or text.") + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + LlavaNextProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) if images is not None: - image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors) + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) else: image_inputs = {} @@ -164,13 +160,7 @@ def __call__( prompt_strings.append(sample) prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings] - text_inputs = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs}) diff --git a/src/transformers/models/pix2struct/processing_pix2struct.py b/src/transformers/models/pix2struct/processing_pix2struct.py index 269fa8c62fb205..de8c594f94c9f2 100644 --- a/src/transformers/models/pix2struct/processing_pix2struct.py +++ b/src/transformers/models/pix2struct/processing_pix2struct.py @@ -18,9 +18,34 @@ from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput + + +class 
Pix2StructImagesKwargs(ImagesKwargs, total=False): + max_patches: Optional[int] + header_text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] + + +class Pix2StructProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Pix2StructImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + "max_patches": 2048, + }, + } class Pix2StructProcessor(ProcessorMixin): @@ -50,23 +75,10 @@ def __call__( self, images=None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - max_patches: Optional[int] = 2048, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_token_type_ids: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, - ) -> BatchEncoding: + audio=None, + videos=None, + **kwargs: Unpack[Pix2StructProcessorKwargs], + ) -> Union[BatchEncoding, BatchFeature]: """ This method uses [`Pix2StructImageProcessor.preprocess`] method to prepare image(s) for the model, and [`T5TokenizerFast.__call__`] to prepare text for the model. @@ -76,59 +88,27 @@ def __call__( if images is None and text is None: raise ValueError("You have to specify either images or text.") + output_kwargs = self._merge_kwargs( + Pix2StructProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) # Get only text if images is None and not self.image_processor.is_vqa: self.current_processor = self.tokenizer - text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) + text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) return text_encoding if not self.image_processor.is_vqa: # add pixel_values - encoding_image_processor = self.image_processor( - images, return_tensors=return_tensors, max_patches=max_patches, **kwargs - ) + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) else: # add pixel_values and bbox - encoding_image_processor = self.image_processor( - images, return_tensors=return_tensors, max_patches=max_patches, header_text=text, **kwargs - ) + output_kwargs["images_kwargs"].setdefault("header_text", text) + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) if text is not None and not self.image_processor.is_vqa: - text_encoding = self.tokenizer( - text=text, - 
add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) + text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) if "attention_mask" in text_encoding: text_encoding["decoder_attention_mask"] = text_encoding.pop("attention_mask") diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index 35000db677d387..ddeb585a757d5d 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -626,7 +626,7 @@ def test_inference(self): image = prepare_img() texts = ["a photo of a cat", "a photo of a dog"] - inputs = processor(text=texts, images=image, return_tensors="pt").to(torch_device) + inputs = processor(images=image, text=texts, return_tensors="pt").to(torch_device) # forward pass with torch.no_grad(): diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py index 6065251c5bb92a..9425bddb6f703c 100644 --- a/tests/models/fuyu/test_modeling_fuyu.py +++ b/tests/models/fuyu/test_modeling_fuyu.py @@ -330,7 +330,7 @@ def test_greedy_generation(self): text_prompt_coco_captioning = "Generate a coco-style caption.\n" - inputs = processor(text=text_prompt_coco_captioning, images=image, return_tensors="pt") + inputs = processor(images=image, text=text_prompt_coco_captioning, return_tensors="pt") generated_ids = model.generate(**inputs, max_new_tokens=10) # take the last 8 tokens (in order to skip special \n\x04 characters) and decode them diff --git a/tests/models/fuyu/test_processing_fuyu.py b/tests/models/fuyu/test_processing_fuyu.py index 459386952c3ed9..69a1d53e86f766 100644 --- a/tests/models/fuyu/test_processing_fuyu.py +++ b/tests/models/fuyu/test_processing_fuyu.py @@ -1,17 +1,25 @@ import io +import tempfile import unittest import requests -from transformers import AutoTokenizer, is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, require_torch_gpu, slow +from transformers import ( + AutoProcessor, + AutoTokenizer, + FuyuImageProcessor, + FuyuProcessor, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import require_torch, require_vision + +from ...test_processing_common import ProcessorTesterMixin if is_vision_available(): from PIL import Image -if is_vision_available() and is_torch_available(): - from transformers import FuyuImageProcessor, FuyuProcessor if is_torch_available(): import torch @@ -20,21 +28,36 @@ @require_torch -@require_torch_gpu -@slow -class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here? 
- """ """ +@require_vision +class FuyuProcessingTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = FuyuProcessor def setUp(self): - pretrained_model_name = "adept/fuyu-8b" - self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) - self.image_processor = FuyuImageProcessor() + self.tmpdirname = tempfile.mkdtemp() + + image_processor = FuyuImageProcessor() + tokenizer = AutoTokenizer.from_pretrained("adept/fuyu-8b") + + processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer) + processor.save_pretrained(self.tmpdirname) - self.processor = FuyuProcessor(image_processor=self.image_processor, tokenizer=self.tokenizer) self.text_prompt = "Generate a coco-style caption.\\n" bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png" self.bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content)) + def get_processor(self): + image_processor = FuyuImageProcessor() + tokenizer = AutoTokenizer.from_pretrained("adept/fuyu-8b") + processor = FuyuProcessor(image_processor, tokenizer, **self.prepare_processor_dict()) + + return processor + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + def test_fuyu_processing(self): """ Test to ensure that the standard processing on a gold example matches adept's code. @@ -43,7 +66,7 @@ def test_fuyu_processing(self): EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64) EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 
71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64) - one_image_bus_model_inputs = self.processor(text=self.text_prompt, images=self.bus_image_pil) + one_image_bus_model_inputs = self.get_processor()(text=self.text_prompt, images=self.bus_image_pil) # fmt: on torch.testing.assert_close(one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS) @@ -53,8 +76,8 @@ def test_fuyu_processing_no_image(self): """ Test to check processor works with just text input """ - processor_outputs = self.processor(text=self.text_prompt) - tokenizer_outputs = self.tokenizer(self.text_prompt) + processor_outputs = self.get_processor()(text=self.text_prompt) + tokenizer_outputs = self.get_tokenizer()(self.text_prompt) self.assertEqual(processor_outputs["input_ids"], tokenizer_outputs["input_ids"]) def test_fuyu_processing_no_text(self): @@ -90,7 +113,7 @@ def test_fuyu_processing_no_text(self): ]).to(torch.int64) # fmt: on - processor_outputs = self.processor(images=self.bus_image_pil) + processor_outputs = self.get_processor()(images=self.bus_image_pil) self.assertTrue((processor_outputs["image_patches_indices"] == EXPECTED_IMAGE_PATCH_INPUTS).all()) def test_fuyu_processing_multiple_image_sample(self): @@ -107,7 +130,7 @@ def test_fuyu_processing_multiple_image_sample(self): # Batch of two images - equally sized images = [self.bus_image_pil, self.bus_image_pil] - processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images) + processor_outputs = 
self.get_processor()(text=[self.text_prompt, self.text_prompt], images=images) self.assertTrue( ( @@ -124,18 +147,18 @@ def test_fuyu_processing_multiple_image_sample(self): # Processes single images with different sizes as expected images = [self.bus_image_pil] - processor_outputs = self.processor(text=self.text_prompt, images=images) + processor_outputs = self.get_processor()(text=self.text_prompt, images=images) self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_IMAGE_PATCH_INPUTS).all()) self.assertTrue((processor_outputs["input_ids"] == SINGLE_PADDED_UNPACKED_TOKEN_INPUTS).all()) images = [self.bus_image_pil.resize((64, 300))] - processor_outputs = self.processor(text=self.text_prompt, images=images) + processor_outputs = self.get_processor()(text=self.text_prompt, images=images) self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_RESIZED_IMAGE_PATCH_INPUTS).all()) self.assertTrue((processor_outputs["input_ids"] == SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS).all()) # Batch of two images - different sizes. Left-pads the smaller image inputs images = [self.bus_image_pil, self.bus_image_pil.resize((64, 300))] - processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images) + processor_outputs = self.get_processor()(text=[self.text_prompt, self.text_prompt], images=images) padding_len_patch = SINGLE_IMAGE_PATCH_INPUTS.shape[1] - SINGLE_RESIZED_IMAGE_PATCH_INPUTS.shape[1] padded_single_resized_image_patch = torch.cat( @@ -156,6 +179,155 @@ def test_fuyu_processing_multiple_image_sample(self): self.assertTrue((processor_outputs["image_patches_indices"] == expected_image_patch_inputs).all()) self.assertTrue((processor_outputs["input_ids"] == expected_padded_unpacked_token_inputs).all()) + # Rewrite as Fuyu supports tokenizer kwargs only when image is None. + @require_vision + @require_torch + def test_kwargs_overrides_default_tokenizer_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=117) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + # Fuyu uses tokenizer kwargs only when image is None. + image_input = None + + inputs = processor( + text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length" + ) + self.assertEqual(len(inputs["input_ids"][0]), 112) + + @unittest.skip("Fuyu processor does not support image_processor kwargs") + def test_image_processor_defaults_preserved_by_image_kwargs(self): + pass + + @unittest.skip("Fuyu processor does not support image_processor kwargs") + def test_kwargs_overrides_default_image_processor_kwargs(self): + pass + + # Rewrite as Fuyu supports tokenizer kwargs only when image is None. 
+ @require_vision + @require_torch + def test_tokenizer_defaults_preserved_by_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + # Fuyu uses tokenizer kwargs only when image is None. + image_input = None + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertEqual(len(inputs["input_ids"][0]), 117) + + # Rewrite as Fuyu image processor does not return pixel values + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + # Fuyu uses tokenizer kwargs only when image is None. + image_input = None + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + # Rewrite as Fuyu image processor does not return pixel values + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + # Fuyu uses tokenizer kwargs only when image is None. + image_input = None + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + # Rewrite as Fuyu supports tokenizer kwargs only when image is None. + @require_torch + @require_vision + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + # Fuyu uses tokenizer kwargs only when image is None. 
+ image_input = None + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + padding="max_length", + max_length=76, + ) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + # Rewrite as Fuyu supports tokenizer kwargs only when image is None. + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + # Fuyu uses tokenizer kwargs only when image is None. + image_input = None + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + padding="longest", + max_length=76, + ) + + self.assertEqual(len(inputs["input_ids"][0]), 6) + @require_torch class TestImageTextProcessingUtils(unittest.TestCase): diff --git a/tests/models/instructblip/test_processor_instructblip.py b/tests/models/instructblip/test_processor_instructblip.py index e03e555fed0857..ffec4b01112c2f 100644 --- a/tests/models/instructblip/test_processor_instructblip.py +++ b/tests/models/instructblip/test_processor_instructblip.py @@ -17,7 +17,7 @@ import pytest -from transformers.testing_utils import require_torch, require_vision +from transformers.testing_utils import require_vision from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -179,261 +179,3 @@ def test_model_input_names(self): list(inputs.keys()), ["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"], ) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_vision - @require_torch - def test_tokenizer_defaults_preserved_by_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") - qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - inputs = processor(text=input_str, images=image_input, return_tensors="pt") - self.assertEqual(len(inputs["input_ids"][0]), 117) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_image_processor_defaults_preserved_by_image_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor", size=(234, 234)) - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") - qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - 
self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - inputs = processor(text=input_str, images=image_input) - self.assertEqual(len(inputs["pixel_values"][0][0]), 234) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_vision - @require_torch - def test_kwargs_overrides_default_tokenizer_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer", padding="longest") - qformer_tokenizer = self.get_component("qformer_tokenizer", padding="longest") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - inputs = processor( - text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length" - ) - self.assertEqual(len(inputs["input_ids"][0]), 112) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_kwargs_overrides_default_image_processor_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor", size=(234, 234)) - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") - qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - inputs = processor(text=input_str, images=image_input, size=[224, 224]) - self.assertEqual(len(inputs["pixel_values"][0][0]), 224) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_unstructured_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - inputs = processor( - text=input_str, - images=image_input, - return_tensors="pt", - size={"height": 214, "width": 214}, - padding="max_length", - max_length=76, - ) - - self.assertEqual(inputs["pixel_values"].shape[2], 214) - self.assertEqual(len(inputs["input_ids"][0]), 76) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_unstructured_kwargs_batched(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - 
qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = ["lower newer", "upper older longer string"] - image_input = self.prepare_image_inputs() * 2 - inputs = processor( - text=input_str, - images=image_input, - return_tensors="pt", - size={"height": 214, "width": 214}, - padding="longest", - max_length=76, - ) - - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 6) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_doubly_passed_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = ["lower newer"] - image_input = self.prepare_image_inputs() - with self.assertRaises(ValueError): - _ = processor( - text=input_str, - images=image_input, - images_kwargs={"size": {"height": 222, "width": 222}}, - size={"height": 214, "width": 214}, - ) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_structured_kwargs_nested(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - # Define the kwargs for each modality - all_kwargs = { - "common_kwargs": {"return_tensors": "pt"}, - "images_kwargs": {"size": {"height": 214, "width": 214}}, - "text_kwargs": {"padding": "max_length", "max_length": 76}, - } - - inputs = processor(text=input_str, images=image_input, **all_kwargs) - self.skip_processor_without_typed_kwargs(processor) - - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 76) - - # Override as InstructBlipProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_structured_kwargs_nested_from_dict(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - # Define the kwargs for each modality - all_kwargs = { - "common_kwargs": {"return_tensors": "pt"}, 
- "images_kwargs": {"size": {"height": 214, "width": 214}}, - "text_kwargs": {"padding": "max_length", "max_length": 76}, - } - - inputs = processor(text=input_str, images=image_input, **all_kwargs) - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 76) - - def test_overlapping_text_kwargs_handling(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - processor_kwargs = {} - processor_kwargs["image_processor"] = self.get_component("image_processor") - processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer") - if not tokenizer.pad_token: - tokenizer.pad_token = "[TEST_PAD]" - if "video_processor" in self.processor_class.attributes: - processor_kwargs["video_processor"] = self.get_component("video_processor") - - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - with self.assertRaises(ValueError): - _ = processor( - text=input_str, - images=image_input, - return_tensors="pt", - padding="max_length", - text_kwargs={"padding": "do_not_pad"}, - ) diff --git a/tests/models/instructblipvideo/test_processor_instructblipvideo.py b/tests/models/instructblipvideo/test_processor_instructblipvideo.py index 8b29c771759217..d613d878223213 100644 --- a/tests/models/instructblipvideo/test_processor_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_processor_instructblipvideo.py @@ -15,18 +15,15 @@ import tempfile import unittest -import numpy as np import pytest -from transformers.testing_utils import require_torch, require_vision +from transformers.testing_utils import require_vision from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin if is_vision_available(): - from PIL import Image - from transformers import ( AutoProcessor, BertTokenizerFast, @@ -65,16 +62,6 @@ def get_qformer_tokenizer(self, **kwargs): def tearDown(self): shutil.rmtree(self.tmpdirname) - # Ignore copy - def prepare_image_inputs(self): - """This function prepares a list of list of PIL images""" - - video_inputs = [ - [Image.fromarray(np.random.randint(255, size=(30, 400, 3), dtype=np.uint8)) for _ in range(5)] - for _ in range(2) - ] - return video_inputs - def test_save_load_pretrained_additional_features(self): processor = InstructBlipVideoProcessor( tokenizer=self.get_tokenizer(), @@ -193,261 +180,3 @@ def test_model_input_names(self): list(inputs.keys()), ["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"], ) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_vision - @require_torch - def test_tokenizer_defaults_preserved_by_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") - qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - 
self.skip_processor_without_typed_kwargs(processor) - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - inputs = processor(text=input_str, images=image_input, return_tensors="pt") - self.assertEqual(len(inputs["input_ids"][0]), 117) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_image_processor_defaults_preserved_by_image_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor", size=(234, 234)) - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") - qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - inputs = processor(text=input_str, images=image_input) - self.assertEqual(len(inputs["pixel_values"][0][0]), 234) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_vision - @require_torch - def test_kwargs_overrides_default_tokenizer_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer", padding="longest") - qformer_tokenizer = self.get_component("qformer_tokenizer", padding="longest") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - inputs = processor( - text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length" - ) - self.assertEqual(len(inputs["input_ids"][0]), 112) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_kwargs_overrides_default_image_processor_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor", size=(234, 234)) - tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") - qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - inputs = processor(text=input_str, images=image_input, size=[224, 224]) - self.assertEqual(len(inputs["pixel_values"][0][0]), 224) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_unstructured_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = 
self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - inputs = processor( - text=input_str, - images=image_input, - return_tensors="pt", - size={"height": 214, "width": 214}, - padding="max_length", - max_length=76, - ) - - self.assertEqual(inputs["pixel_values"].shape[2], 214) - self.assertEqual(len(inputs["input_ids"][0]), 76) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_unstructured_kwargs_batched(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = ["lower newer", "upper older longer string"] - image_input = self.prepare_image_inputs() * 2 - inputs = processor( - text=input_str, - images=image_input, - return_tensors="pt", - size={"height": 214, "width": 214}, - padding="longest", - max_length=76, - ) - - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 6) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_doubly_passed_kwargs(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = ["lower newer"] - image_input = self.prepare_image_inputs() - with self.assertRaises(ValueError): - _ = processor( - text=input_str, - images=image_input, - images_kwargs={"size": {"height": 222, "width": 222}}, - size={"height": 214, "width": 214}, - ) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_structured_kwargs_nested(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - # Define the kwargs for each modality - all_kwargs = { - "common_kwargs": {"return_tensors": "pt"}, - "images_kwargs": {"size": {"height": 214, "width": 214}}, - "text_kwargs": {"padding": "max_length", "max_length": 76}, - } - - inputs = processor(text=input_str, images=image_input, **all_kwargs) - 
self.skip_processor_without_typed_kwargs(processor) - - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 76) - - # Override as InstructBlipVideoProcessor has qformer_tokenizer - @require_torch - @require_vision - def test_structured_kwargs_nested_from_dict(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - - image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer") - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - self.skip_processor_without_typed_kwargs(processor) - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - # Define the kwargs for each modality - all_kwargs = { - "common_kwargs": {"return_tensors": "pt"}, - "images_kwargs": {"size": {"height": 214, "width": 214}}, - "text_kwargs": {"padding": "max_length", "max_length": 76}, - } - - inputs = processor(text=input_str, images=image_input, **all_kwargs) - self.assertEqual(inputs["pixel_values"].shape[2], 214) - - self.assertEqual(len(inputs["input_ids"][0]), 76) - - def test_overlapping_text_kwargs_handling(self): - if "image_processor" not in self.processor_class.attributes: - self.skipTest(f"image_processor attribute not present in {self.processor_class}") - processor_kwargs = {} - processor_kwargs["image_processor"] = self.get_component("image_processor") - processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer") - if not tokenizer.pad_token: - tokenizer.pad_token = "[TEST_PAD]" - if "video_processor" in self.processor_class.attributes: - processor_kwargs["video_processor"] = self.get_component("video_processor") - - qformer_tokenizer = self.get_component("qformer_tokenizer") - - processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer) - self.skip_processor_without_typed_kwargs(processor) - - input_str = "lower newer" - image_input = self.prepare_image_inputs() - - with self.assertRaises(ValueError): - _ = processor( - text=input_str, - images=image_input, - return_tensors="pt", - padding="max_length", - text_kwargs={"padding": "do_not_pad"}, - ) diff --git a/tests/models/kosmos2/test_processor_kosmos2.py b/tests/models/kosmos2/test_processor_kosmos2.py index e07ba5fc106b6c..8de398ade70c71 100644 --- a/tests/models/kosmos2/test_processor_kosmos2.py +++ b/tests/models/kosmos2/test_processor_kosmos2.py @@ -61,7 +61,7 @@ class Kosmos2ProcessorTest(ProcessorTesterMixin, unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() - image_processor = CLIPImageProcessor() + image_processor = CLIPImageProcessor(do_center_crop=False) # We have a SentencePiece fixture for testing slow_tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB) @@ -487,3 +487,147 @@ def check(texts, bboxes, expected_input_ids): self.assertListEqual(outputs.input_ids.numpy().tolist()[-1], EXPECTED_IDS_BATCH[-1]) self.assertListEqual(outputs.attention_mask.numpy().tolist()[-1], EXPECTED_MASK_BATCH[-1]) self.assertListEqual(outputs.image_embeds_position_mask.numpy().tolist()[-1], EXPECTED_IMG_POS_MASK_BATCH[-1]) + + # Rewrite as Kosmos-2 supports custom padding only when image is None. 
+ @require_vision + @require_torch + def test_kwargs_overrides_default_tokenizer_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=117) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + # set image input to None + image_input = None + + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + max_length=112, + padding="max_length", + ) + + self.assertEqual(len(inputs["input_ids"][0]), 112) + + # Rewrite to test only image_processor kwargs + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"size": {"height": 214, "width": 214}}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + # Rewrite to test only image_processor kwargs + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"size": {"height": 214, "width": 214}}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + # Rewrite as Kosmos-2 supports custom padding only when image is None. + @require_vision + @require_torch + def test_tokenizer_defaults_preserved_by_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + # set image input to None + image_input = None + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertEqual(len(inputs["input_ids"][0]), 117) + + # Rewrite as Kosmos-2 supports custom padding only when image is None. 
+ @require_torch + @require_vision + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + # set image input to None + image_input = None + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + padding="max_length", + max_length=76, + ) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + # Rewrite as Kosmos-2 supports custom padding only when image is None. + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + # set image input to None + image_input = None + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + + self.assertEqual(len(inputs["input_ids"][0]), 10) diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 772f19e13a4be5..a54aeab8a28252 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -338,7 +338,7 @@ def test_small_model_integration_test(self): load_in_4bit=True, ) - inputs = self.processor(self.prompt, self.image, return_tensors="pt") + inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt") # verify inputs against original implementation filepath = hf_hub_download( @@ -390,8 +390,8 @@ def test_small_model_integration_test_batch(self): cats_image = Image.open(requests.get(url, stream=True).raw) inputs = self.processor( - [self.prompt, self.prompt], images=[self.image, cats_image], + text=[self.prompt, self.prompt], return_tensors="pt", padding=True, ).to(torch_device) @@ -415,7 +415,7 @@ def test_small_model_integration_test_unk_token(self): ) prompt_with_unk = "[INST] \nWhat is shown in this image? 
[/INST]" - inputs = self.processor(prompt_with_unk, self.image, return_tensors="pt") + inputs = self.processor(images=self.image, text=prompt_with_unk, return_tensors="pt") # verify single forward pass inputs = inputs.to(torch_device) @@ -445,7 +445,7 @@ def test_small_model_integration_test_batch_different_resolutions(self): lowres_img = Image.open(requests.get(lowres_url, stream=True).raw) inputs = self.processor( - [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True + images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True ).to(torch_device) pixel_values = inputs["pixel_values"] @@ -498,10 +498,10 @@ def test_small_model_integration_test_batch_matches_single(self): lowres_img = Image.open(requests.get(lowres_url, stream=True).raw) inputs_batched = self.processor( - [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True + images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True ).to(torch_device) - inputs_single = self.processor(self.prompt, images=lowres_img, return_tensors="pt", padding=True).to( + inputs_single = self.processor(images=lowres_img, text=self.prompt, return_tensors="pt", padding=True).to( torch_device ) @@ -527,7 +527,7 @@ def test_padding_side_when_merging_inputs(self): lowres_img = Image.open(requests.get(lowres_url, stream=True).raw) inputs_batched = self.processor( - [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True + images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True ).to(torch_device) # model is in eval mode by default so we should get pad on the left side @@ -607,13 +607,13 @@ def test_expansion_in_processing(self): # check processing with expansion of inputs processor.vision_feature_select_strategy = "default" processor.patch_size = 14 - inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2356) # check processing without expansion of inputs (legacy behavior) processor.vision_feature_select_strategy = None processor.patch_size = None - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) self.assertTrue(inputs.input_ids.shape[-1] == 17) # generate exactly 20 tokens diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py index 450034f4151dd0..45faa24526305c 100644 --- a/tests/models/llava_next/test_processor_llava_next.py +++ b/tests/models/llava_next/test_processor_llava_next.py @@ -18,7 +18,9 @@ import torch from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor -from transformers.testing_utils import require_vision +from transformers.testing_utils import ( + require_vision, +) from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin diff --git a/tests/models/pix2struct/test_processor_pix2struct.py b/tests/models/pix2struct/test_processor_pix2struct.py index 17b3298145f823..ac8d4822f1c09f 100644 --- a/tests/models/pix2struct/test_processor_pix2struct.py +++ b/tests/models/pix2struct/test_processor_pix2struct.py 
@@ -37,6 +37,8 @@
 @require_torch
 class Pix2StructProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     processor_class = Pix2StructProcessor
+    text_input_name = "decoder_input_ids"
+    images_input_name = "flattened_patches"
 
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()
@@ -180,3 +182,148 @@ def test_model_input_names(self):
 
         # For now the processor supports only ["flattened_patches", "input_ids", "attention_mask", "decoder_attention_mask"]
         self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
+
+    @require_torch
+    @require_vision
+    def test_image_processor_defaults_preserved_by_image_kwargs(self):
+        # Rewrite as pix2struct processor returns "flattened_patches" and not "pixel_values"
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", max_patches=1024, patch_size={"height": 8, "width": 8})
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input)
+        self.assertEqual(len(inputs["flattened_patches"][0][0]), 194)
+
+    @require_torch
+    @require_vision
+    def test_kwargs_overrides_default_image_processor_kwargs(self):
+        # Rewrite as pix2struct processor returns "flattened_patches" and not "pixel_values"
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", max_patches=4096)
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input, max_patches=1024)
+        self.assertEqual(len(inputs["flattened_patches"][0]), 1024)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs(self):
+        # Rewrite as pix2struct processor returns "decoder_input_ids" and not "input_ids"
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            max_patches=1024,
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
+        self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        # Rewrite as pix2struct processor returns "decoder_input_ids" and not "input_ids"
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            max_patches=1024,
+            padding="longest",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
+
+        self.assertEqual(len(inputs["decoder_input_ids"][0]), 5)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested(self):
+        # Rewrite as pix2struct processor returns "decoder_input_ids" and not "input_ids"
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"max_patches": 1024},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
+
+        self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested_from_dict(self):
+        # Rewrite as pix2struct processor returns "decoder_input_ids" and not "input_ids"
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"max_patches": 1024},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
+
+        self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)
diff --git a/tests/models/pixtral/test_processor_pixtral.py b/tests/models/pixtral/test_processor_pixtral.py
index 29575b49367268..59c19eabcaf53b 100644
--- a/tests/models/pixtral/test_processor_pixtral.py
+++ b/tests/models/pixtral/test_processor_pixtral.py
@@ -18,9 +18,7 @@
 import requests
 import torch
 
-from transformers.testing_utils import (
-    require_vision,
-)
+from transformers.testing_utils import require_vision
 from transformers.utils import is_vision_available
 
 from ...test_processing_common import ProcessorTesterMixin
diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py
index a51c1d200eb0aa..1b5eabec13c05d 100644
--- a/tests/test_processing_common.py
+++ b/tests/test_processing_common.py
@@ -65,7 +65,7 @@ def get_component(self, attribute, **kwargs):
 
         component_class = processor_class_from_name(component_class_name)
         component = component_class.from_pretrained(self.tmpdirname, **kwargs)  # noqa
-        if attribute == "tokenizer" and not component.pad_token:
+        if "tokenizer" in attribute and not component.pad_token:
             component.pad_token = "[TEST_PAD]"
             if component.pad_token_id is None:
                 component.pad_token_id = 0
@@ -321,14 +321,8 @@ def test_structured_kwargs_nested_from_dict(self):
     def test_overlapping_text_kwargs_handling(self):
         if "image_processor" not in self.processor_class.attributes:
             self.skipTest(f"image_processor attribute not present in {self.processor_class}")
-        processor_kwargs = {}
-        processor_kwargs["image_processor"] = self.get_component("image_processor")
-        processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
-        if not tokenizer.pad_token:
-            tokenizer.pad_token = "[TEST_PAD]"
-        if "video_processor" in self.processor_class.attributes:
-            processor_kwargs["video_processor"] = self.get_component("video_processor")
-        processor = self.processor_class(**processor_kwargs)
+        processor_components = self.prepare_components()
+        processor = self.processor_class(**processor_components)
         self.skip_processor_without_typed_kwargs(processor)
 
         input_str = "lower newer"